diff --git a/libnetwork/drivers/bridge/setup_ip_tables.go b/libnetwork/drivers/bridge/setup_ip_tables.go index 392965803b..c2ef682e3b 100644 --- a/libnetwork/drivers/bridge/setup_ip_tables.go +++ b/libnetwork/drivers/bridge/setup_ip_tables.go @@ -4,9 +4,9 @@ import ( "fmt" "net" - "github.com/docker/docker/daemon/networkdriver" - "github.com/docker/docker/daemon/networkdriver/portmapper" "github.com/docker/docker/pkg/iptables" + "github.com/docker/libnetwork" + "github.com/docker/libnetwork/portmapper" ) // DockerChain: DOCKER iptable chain name @@ -20,7 +20,7 @@ func setupIPTables(i *bridgeInterface) error { return fmt.Errorf("Unexpected request to set IP tables for interface: %s", i.Config.BridgeName) } - addrv4, _, err := networkdriver.GetIfaceAddr(i.Config.BridgeName) + addrv4, _, err := libnetwork.GetIfaceAddr(i.Config.BridgeName) if err != nil { return fmt.Errorf("Failed to setup IP tables, cannot acquire Interface address: %s", err.Error()) } diff --git a/libnetwork/portallocator/portallocator.go b/libnetwork/portallocator/portallocator.go new file mode 100644 index 0000000000..ed913ddc7e --- /dev/null +++ b/libnetwork/portallocator/portallocator.go @@ -0,0 +1,162 @@ +package portallocator + +import ( + "errors" + "fmt" + "net" + "sync" +) + +type portMap struct { + p map[int]struct{} + last int +} + +func newPortMap() *portMap { + return &portMap{ + p: map[int]struct{}{}, + last: EndPortRange, + } +} + +type protoMap map[string]*portMap + +func newProtoMap() protoMap { + return protoMap{ + "tcp": newPortMap(), + "udp": newPortMap(), + } +} + +type ipMapping map[string]protoMap + +const ( + // BeginPortRange indicates the first port in port range + BeginPortRange = 49153 + // EndPortRange indicates the last port in port range + EndPortRange = 65535 +) + +var ( + // ErrAllPortsAllocated is returned when no more ports are available + ErrAllPortsAllocated = errors.New("all ports are allocated") + // ErrUnknownProtocol is returned when an unknown protocol was specified + ErrUnknownProtocol = errors.New("unknown protocol") +) + +var ( + mutex sync.Mutex + + defaultIP = net.ParseIP("0.0.0.0") + globalMap = ipMapping{} +) + +// ErrPortAlreadyAllocated is the returned error information when a requested port is already being used +type ErrPortAlreadyAllocated struct { + ip string + port int +} + +func newErrPortAlreadyAllocated(ip string, port int) ErrPortAlreadyAllocated { + return ErrPortAlreadyAllocated{ + ip: ip, + port: port, + } +} + +// IP returns the address to which the used port is associated +func (e ErrPortAlreadyAllocated) IP() string { + return e.ip +} + +// Port returns the value of the already used port +func (e ErrPortAlreadyAllocated) Port() int { + return e.port +} + +// IPPort returns the address and the port in the form ip:port +func (e ErrPortAlreadyAllocated) IPPort() string { + return fmt.Sprintf("%s:%d", e.ip, e.port) +} + +// Error is the implementation of error.Error interface +func (e ErrPortAlreadyAllocated) Error() string { + return fmt.Sprintf("Bind for %s:%d failed: port is already allocated", e.ip, e.port) +} + +// RequestPort requests new port from global ports pool for specified ip and proto. +// If port is 0 it returns first free port. Otherwise it cheks port availability +// in pool and return that port or error if port is already busy. +func RequestPort(ip net.IP, proto string, port int) (int, error) { + mutex.Lock() + defer mutex.Unlock() + + if proto != "tcp" && proto != "udp" { + return 0, ErrUnknownProtocol + } + + if ip == nil { + ip = defaultIP + } + ipstr := ip.String() + protomap, ok := globalMap[ipstr] + if !ok { + protomap = newProtoMap() + globalMap[ipstr] = protomap + } + mapping := protomap[proto] + if port > 0 { + if _, ok := mapping.p[port]; !ok { + mapping.p[port] = struct{}{} + return port, nil + } + return 0, newErrPortAlreadyAllocated(ipstr, port) + } + + port, err := mapping.findPort() + if err != nil { + return 0, err + } + return port, nil +} + +// ReleasePort releases port from global ports pool for specified ip and proto. +func ReleasePort(ip net.IP, proto string, port int) error { + mutex.Lock() + defer mutex.Unlock() + + if ip == nil { + ip = defaultIP + } + protomap, ok := globalMap[ip.String()] + if !ok { + return nil + } + delete(protomap[proto].p, port) + return nil +} + +// ReleaseAll releases all ports for all ips. +func ReleaseAll() error { + mutex.Lock() + globalMap = ipMapping{} + mutex.Unlock() + return nil +} + +func (pm *portMap) findPort() (int, error) { + port := pm.last + for i := 0; i <= EndPortRange-BeginPortRange; i++ { + port++ + if port > EndPortRange { + port = BeginPortRange + } + + if _, ok := pm.p[port]; !ok { + pm.p[port] = struct{}{} + pm.last = port + return port, nil + } + } + return 0, ErrAllPortsAllocated +} diff --git a/libnetwork/portallocator/portallocator_test.go b/libnetwork/portallocator/portallocator_test.go new file mode 100644 index 0000000000..72581f1040 --- /dev/null +++ b/libnetwork/portallocator/portallocator_test.go @@ -0,0 +1,245 @@ +package portallocator + +import ( + "net" + "testing" +) + +func reset() { + ReleaseAll() +} + +func TestRequestNewPort(t *testing.T) { + defer reset() + + port, err := RequestPort(defaultIP, "tcp", 0) + if err != nil { + t.Fatal(err) + } + + if expected := BeginPortRange; port != expected { + t.Fatalf("Expected port %d got %d", expected, port) + } +} + +func TestRequestSpecificPort(t *testing.T) { + defer reset() + + port, err := RequestPort(defaultIP, "tcp", 5000) + if err != nil { + t.Fatal(err) + } + if port != 5000 { + t.Fatalf("Expected port 5000 got %d", port) + } +} + +func TestReleasePort(t *testing.T) { + defer reset() + + port, err := RequestPort(defaultIP, "tcp", 5000) + if err != nil { + t.Fatal(err) + } + if port != 5000 { + t.Fatalf("Expected port 5000 got %d", port) + } + + if err := ReleasePort(defaultIP, "tcp", 5000); err != nil { + t.Fatal(err) + } +} + +func TestReuseReleasedPort(t *testing.T) { + defer reset() + + port, err := RequestPort(defaultIP, "tcp", 5000) + if err != nil { + t.Fatal(err) + } + if port != 5000 { + t.Fatalf("Expected port 5000 got %d", port) + } + + if err := ReleasePort(defaultIP, "tcp", 5000); err != nil { + t.Fatal(err) + } + + port, err = RequestPort(defaultIP, "tcp", 5000) + if err != nil { + t.Fatal(err) + } +} + +func TestReleaseUnreadledPort(t *testing.T) { + defer reset() + + port, err := RequestPort(defaultIP, "tcp", 5000) + if err != nil { + t.Fatal(err) + } + if port != 5000 { + t.Fatalf("Expected port 5000 got %d", port) + } + + port, err = RequestPort(defaultIP, "tcp", 5000) + + switch err.(type) { + case ErrPortAlreadyAllocated: + default: + t.Fatalf("Expected port allocation error got %s", err) + } +} + +func TestUnknowProtocol(t *testing.T) { + defer reset() + + if _, err := RequestPort(defaultIP, "tcpp", 0); err != ErrUnknownProtocol { + t.Fatalf("Expected error %s got %s", ErrUnknownProtocol, err) + } +} + +func TestAllocateAllPorts(t *testing.T) { + defer reset() + + for i := 0; i <= EndPortRange-BeginPortRange; i++ { + port, err := RequestPort(defaultIP, "tcp", 0) + if err != nil { + t.Fatal(err) + } + + if expected := BeginPortRange + i; port != expected { + t.Fatalf("Expected port %d got %d", expected, port) + } + } + + if _, err := RequestPort(defaultIP, "tcp", 0); err != ErrAllPortsAllocated { + t.Fatalf("Expected error %s got %s", ErrAllPortsAllocated, err) + } + + _, err := RequestPort(defaultIP, "udp", 0) + if err != nil { + t.Fatal(err) + } + + // release a port in the middle and ensure we get another tcp port + port := BeginPortRange + 5 + if err := ReleasePort(defaultIP, "tcp", port); err != nil { + t.Fatal(err) + } + newPort, err := RequestPort(defaultIP, "tcp", 0) + if err != nil { + t.Fatal(err) + } + if newPort != port { + t.Fatalf("Expected port %d got %d", port, newPort) + } + + // now pm.last == newPort, release it so that it's the only free port of + // the range, and ensure we get it back + if err := ReleasePort(defaultIP, "tcp", newPort); err != nil { + t.Fatal(err) + } + port, err = RequestPort(defaultIP, "tcp", 0) + if err != nil { + t.Fatal(err) + } + if newPort != port { + t.Fatalf("Expected port %d got %d", newPort, port) + } +} + +func BenchmarkAllocatePorts(b *testing.B) { + defer reset() + + for i := 0; i < b.N; i++ { + for i := 0; i <= EndPortRange-BeginPortRange; i++ { + port, err := RequestPort(defaultIP, "tcp", 0) + if err != nil { + b.Fatal(err) + } + + if expected := BeginPortRange + i; port != expected { + b.Fatalf("Expected port %d got %d", expected, port) + } + } + reset() + } +} + +func TestPortAllocation(t *testing.T) { + defer reset() + + ip := net.ParseIP("192.168.0.1") + ip2 := net.ParseIP("192.168.0.2") + if port, err := RequestPort(ip, "tcp", 80); err != nil { + t.Fatal(err) + } else if port != 80 { + t.Fatalf("Acquire(80) should return 80, not %d", port) + } + port, err := RequestPort(ip, "tcp", 0) + if err != nil { + t.Fatal(err) + } + if port <= 0 { + t.Fatalf("Acquire(0) should return a non-zero port") + } + + if _, err := RequestPort(ip, "tcp", port); err == nil { + t.Fatalf("Acquiring a port already in use should return an error") + } + + if newPort, err := RequestPort(ip, "tcp", 0); err != nil { + t.Fatal(err) + } else if newPort == port { + t.Fatalf("Acquire(0) allocated the same port twice: %d", port) + } + + if _, err := RequestPort(ip, "tcp", 80); err == nil { + t.Fatalf("Acquiring a port already in use should return an error") + } + if _, err := RequestPort(ip2, "tcp", 80); err != nil { + t.Fatalf("It should be possible to allocate the same port on a different interface") + } + if _, err := RequestPort(ip2, "tcp", 80); err == nil { + t.Fatalf("Acquiring a port already in use should return an error") + } + if err := ReleasePort(ip, "tcp", 80); err != nil { + t.Fatal(err) + } + if _, err := RequestPort(ip, "tcp", 80); err != nil { + t.Fatal(err) + } + + port, err = RequestPort(ip, "tcp", 0) + if err != nil { + t.Fatal(err) + } + port2, err := RequestPort(ip, "tcp", port+1) + if err != nil { + t.Fatal(err) + } + port3, err := RequestPort(ip, "tcp", 0) + if err != nil { + t.Fatal(err) + } + if port3 == port2 { + t.Fatal("Requesting a dynamic port should never allocate a used port") + } +} + +func TestNoDuplicateBPR(t *testing.T) { + defer reset() + + if port, err := RequestPort(defaultIP, "tcp", BeginPortRange); err != nil { + t.Fatal(err) + } else if port != BeginPortRange { + t.Fatalf("Expected port %d got %d", BeginPortRange, port) + } + + if port, err := RequestPort(defaultIP, "tcp", 0); err != nil { + t.Fatal(err) + } else if port == BeginPortRange { + t.Fatalf("Acquire(0) allocated the same port twice: %d", port) + } +} diff --git a/libnetwork/portmapper/mapper.go b/libnetwork/portmapper/mapper.go new file mode 100644 index 0000000000..3064eccd94 --- /dev/null +++ b/libnetwork/portmapper/mapper.go @@ -0,0 +1,182 @@ +package portmapper + +import ( + "errors" + "fmt" + "net" + "sync" + + log "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/iptables" + "github.com/docker/libnetwork/portallocator" +) + +type mapping struct { + proto string + userlandProxy userlandProxy + host net.Addr + container net.Addr +} + +var ( + chain *iptables.Chain + lock sync.Mutex + + // udp:ip:port + currentMappings = make(map[string]*mapping) + + newProxy = newProxyCommand +) + +var ( + // ErrUnknownBackendAddressType refers to an unknown container or unsupported address type + ErrUnknownBackendAddressType = errors.New("unknown container address type not supported") + // ErrPortMappedForIP refers to a port already mapped to an ip address + ErrPortMappedForIP = errors.New("port is already mapped to ip") + // ErrPortNotMapped refers to an unampped port + ErrPortNotMapped = errors.New("port is not mapped") +) + +// SetIptablesChain sets the specified chain into portmapper +func SetIptablesChain(c *iptables.Chain) { + chain = c +} + +// Map maps the specified container transport address to the host's network address and transport port +func Map(container net.Addr, hostIP net.IP, hostPort int) (host net.Addr, err error) { + lock.Lock() + defer lock.Unlock() + + var ( + m *mapping + proto string + allocatedHostPort int + proxy userlandProxy + ) + + switch container.(type) { + case *net.TCPAddr: + proto = "tcp" + if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil { + return nil, err + } + + m = &mapping{ + proto: proto, + host: &net.TCPAddr{IP: hostIP, Port: allocatedHostPort}, + container: container, + } + + proxy = newProxy(proto, hostIP, allocatedHostPort, container.(*net.TCPAddr).IP, container.(*net.TCPAddr).Port) + case *net.UDPAddr: + proto = "udp" + if allocatedHostPort, err = portallocator.RequestPort(hostIP, proto, hostPort); err != nil { + return nil, err + } + + m = &mapping{ + proto: proto, + host: &net.UDPAddr{IP: hostIP, Port: allocatedHostPort}, + container: container, + } + + proxy = newProxy(proto, hostIP, allocatedHostPort, container.(*net.UDPAddr).IP, container.(*net.UDPAddr).Port) + default: + return nil, ErrUnknownBackendAddressType + } + + // release the allocated port on any further error during return. + defer func() { + if err != nil { + portallocator.ReleasePort(hostIP, proto, allocatedHostPort) + } + }() + + key := getKey(m.host) + if _, exists := currentMappings[key]; exists { + return nil, ErrPortMappedForIP + } + + containerIP, containerPort := getIPAndPort(m.container) + if err := forward(iptables.Append, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort); err != nil { + return nil, err + } + + cleanup := func() error { + // need to undo the iptables rules before we return + proxy.Stop() + forward(iptables.Delete, m.proto, hostIP, allocatedHostPort, containerIP.String(), containerPort) + if err := portallocator.ReleasePort(hostIP, m.proto, allocatedHostPort); err != nil { + return err + } + + return nil + } + + if err := proxy.Start(); err != nil { + if err := cleanup(); err != nil { + return nil, fmt.Errorf("Error during port allocation cleanup: %v", err) + } + return nil, err + } + m.userlandProxy = proxy + currentMappings[key] = m + return m.host, nil +} + +// Unmap removes stored mapping for the specified host transport address +func Unmap(host net.Addr) error { + lock.Lock() + defer lock.Unlock() + + key := getKey(host) + data, exists := currentMappings[key] + if !exists { + return ErrPortNotMapped + } + + data.userlandProxy.Stop() + + delete(currentMappings, key) + + containerIP, containerPort := getIPAndPort(data.container) + hostIP, hostPort := getIPAndPort(data.host) + if err := forward(iptables.Delete, data.proto, hostIP, hostPort, containerIP.String(), containerPort); err != nil { + log.Errorf("Error on iptables delete: %s", err) + } + + switch a := host.(type) { + case *net.TCPAddr: + return portallocator.ReleasePort(a.IP, "tcp", a.Port) + case *net.UDPAddr: + return portallocator.ReleasePort(a.IP, "udp", a.Port) + } + return nil +} + +func getKey(a net.Addr) string { + switch t := a.(type) { + case *net.TCPAddr: + return fmt.Sprintf("%s:%d/%s", t.IP.String(), t.Port, "tcp") + case *net.UDPAddr: + return fmt.Sprintf("%s:%d/%s", t.IP.String(), t.Port, "udp") + } + return "" +} + +func getIPAndPort(a net.Addr) (net.IP, int) { + switch t := a.(type) { + case *net.TCPAddr: + return t.IP, t.Port + case *net.UDPAddr: + return t.IP, t.Port + } + return nil, 0 +} + +func forward(action iptables.Action, proto string, sourceIP net.IP, sourcePort int, containerIP string, containerPort int) error { + if chain == nil { + return nil + } + return chain.Forward(action, sourceIP, sourcePort, proto, containerIP, containerPort) +} diff --git a/libnetwork/portmapper/mapper_test.go b/libnetwork/portmapper/mapper_test.go new file mode 100644 index 0000000000..257916ba21 --- /dev/null +++ b/libnetwork/portmapper/mapper_test.go @@ -0,0 +1,152 @@ +package portmapper + +import ( + "net" + "testing" + + "github.com/docker/docker/pkg/iptables" + "github.com/docker/libnetwork/portallocator" +) + +func init() { + // override this func to mock out the proxy server + newProxy = newMockProxyCommand +} + +func reset() { + chain = nil + currentMappings = make(map[string]*mapping) +} + +func TestSetIptablesChain(t *testing.T) { + defer reset() + + c := &iptables.Chain{ + Name: "TEST", + Bridge: "192.168.1.1", + } + + if chain != nil { + t.Fatal("chain should be nil at init") + } + + SetIptablesChain(c) + if chain == nil { + t.Fatal("chain should not be nil after set") + } +} + +func TestMapPorts(t *testing.T) { + dstIP1 := net.ParseIP("192.168.0.1") + dstIP2 := net.ParseIP("192.168.0.2") + dstAddr1 := &net.TCPAddr{IP: dstIP1, Port: 80} + dstAddr2 := &net.TCPAddr{IP: dstIP2, Port: 80} + + srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")} + srcAddr2 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.2")} + + addrEqual := func(addr1, addr2 net.Addr) bool { + return (addr1.Network() == addr2.Network()) && (addr1.String() == addr2.String()) + } + + if host, err := Map(srcAddr1, dstIP1, 80); err != nil { + t.Fatalf("Failed to allocate port: %s", err) + } else if !addrEqual(dstAddr1, host) { + t.Fatalf("Incorrect mapping result: expected %s:%s, got %s:%s", + dstAddr1.String(), dstAddr1.Network(), host.String(), host.Network()) + } + + if _, err := Map(srcAddr1, dstIP1, 80); err == nil { + t.Fatalf("Port is in use - mapping should have failed") + } + + if _, err := Map(srcAddr2, dstIP1, 80); err == nil { + t.Fatalf("Port is in use - mapping should have failed") + } + + if _, err := Map(srcAddr2, dstIP2, 80); err != nil { + t.Fatalf("Failed to allocate port: %s", err) + } + + if Unmap(dstAddr1) != nil { + t.Fatalf("Failed to release port") + } + + if Unmap(dstAddr2) != nil { + t.Fatalf("Failed to release port") + } + + if Unmap(dstAddr2) == nil { + t.Fatalf("Port already released, but no error reported") + } +} + +func TestGetUDPKey(t *testing.T) { + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 53} + + key := getKey(addr) + + if expected := "192.168.1.5:53/udp"; key != expected { + t.Fatalf("expected key %s got %s", expected, key) + } +} + +func TestGetTCPKey(t *testing.T) { + addr := &net.TCPAddr{IP: net.ParseIP("192.168.1.5"), Port: 80} + + key := getKey(addr) + + if expected := "192.168.1.5:80/tcp"; key != expected { + t.Fatalf("expected key %s got %s", expected, key) + } +} + +func TestGetUDPIPAndPort(t *testing.T) { + addr := &net.UDPAddr{IP: net.ParseIP("192.168.1.5"), Port: 53} + + ip, port := getIPAndPort(addr) + if expected := "192.168.1.5"; ip.String() != expected { + t.Fatalf("expected ip %s got %s", expected, ip) + } + + if ep := 53; port != ep { + t.Fatalf("expected port %d got %d", ep, port) + } +} + +func TestMapAllPortsSingleInterface(t *testing.T) { + dstIP1 := net.ParseIP("0.0.0.0") + srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")} + + hosts := []net.Addr{} + var host net.Addr + var err error + + defer func() { + for _, val := range hosts { + Unmap(val) + } + }() + + for i := 0; i < 10; i++ { + for i := portallocator.BeginPortRange; i < portallocator.EndPortRange; i++ { + if host, err = Map(srcAddr1, dstIP1, 0); err != nil { + t.Fatal(err) + } + + hosts = append(hosts, host) + } + + if _, err := Map(srcAddr1, dstIP1, portallocator.BeginPortRange); err == nil { + t.Fatalf("Port %d should be bound but is not", portallocator.BeginPortRange) + } + + for _, val := range hosts { + if err := Unmap(val); err != nil { + t.Fatal(err) + } + } + + hosts = []net.Addr{} + } +} diff --git a/libnetwork/portmapper/mock_proxy.go b/libnetwork/portmapper/mock_proxy.go new file mode 100644 index 0000000000..29b1605889 --- /dev/null +++ b/libnetwork/portmapper/mock_proxy.go @@ -0,0 +1,18 @@ +package portmapper + +import "net" + +func newMockProxyCommand(proto string, hostIP net.IP, hostPort int, containerIP net.IP, containerPort int) userlandProxy { + return &mockProxyCommand{} +} + +type mockProxyCommand struct { +} + +func (p *mockProxyCommand) Start() error { + return nil +} + +func (p *mockProxyCommand) Stop() error { + return nil +} diff --git a/libnetwork/portmapper/proxy.go b/libnetwork/portmapper/proxy.go new file mode 100644 index 0000000000..5cbb4dc2a8 --- /dev/null +++ b/libnetwork/portmapper/proxy.go @@ -0,0 +1,161 @@ +package portmapper + +import ( + "flag" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/docker/docker/pkg/proxy" + "github.com/docker/docker/pkg/reexec" +) + +const userlandProxyCommandName = "docker-proxy" + +func init() { + reexec.Register(userlandProxyCommandName, execProxy) +} + +type userlandProxy interface { + Start() error + Stop() error +} + +// proxyCommand wraps an exec.Cmd to run the userland TCP and UDP +// proxies as separate processes. +type proxyCommand struct { + cmd *exec.Cmd +} + +// execProxy is the reexec function that is registered to start the userland proxies +func execProxy() { + f := os.NewFile(3, "signal-parent") + host, container := parseHostContainerAddrs() + + p, err := proxy.NewProxy(host, container) + if err != nil { + fmt.Fprintf(f, "1\n%s", err) + f.Close() + os.Exit(1) + } + go handleStopSignals(p) + fmt.Fprint(f, "0\n") + f.Close() + + // Run will block until the proxy stops + p.Run() +} + +// parseHostContainerAddrs parses the flags passed on reexec to create the TCP or UDP +// net.Addrs to map the host and container ports +func parseHostContainerAddrs() (host net.Addr, container net.Addr) { + var ( + proto = flag.String("proto", "tcp", "proxy protocol") + hostIP = flag.String("host-ip", "", "host ip") + hostPort = flag.Int("host-port", -1, "host port") + containerIP = flag.String("container-ip", "", "container ip") + containerPort = flag.Int("container-port", -1, "container port") + ) + + flag.Parse() + + switch *proto { + case "tcp": + host = &net.TCPAddr{IP: net.ParseIP(*hostIP), Port: *hostPort} + container = &net.TCPAddr{IP: net.ParseIP(*containerIP), Port: *containerPort} + case "udp": + host = &net.UDPAddr{IP: net.ParseIP(*hostIP), Port: *hostPort} + container = &net.UDPAddr{IP: net.ParseIP(*containerIP), Port: *containerPort} + default: + log.Fatalf("unsupported protocol %s", *proto) + } + + return host, container +} + +func handleStopSignals(p proxy.Proxy) { + s := make(chan os.Signal, 10) + signal.Notify(s, os.Interrupt, syscall.SIGTERM, syscall.SIGSTOP) + + for _ = range s { + p.Close() + + os.Exit(0) + } +} + +func newProxyCommand(proto string, hostIP net.IP, hostPort int, containerIP net.IP, containerPort int) userlandProxy { + args := []string{ + userlandProxyCommandName, + "-proto", proto, + "-host-ip", hostIP.String(), + "-host-port", strconv.Itoa(hostPort), + "-container-ip", containerIP.String(), + "-container-port", strconv.Itoa(containerPort), + } + + return &proxyCommand{ + cmd: &exec.Cmd{ + Path: reexec.Self(), + Args: args, + SysProcAttr: &syscall.SysProcAttr{ + Pdeathsig: syscall.SIGTERM, // send a sigterm to the proxy if the daemon process dies + }, + }, + } +} + +func (p *proxyCommand) Start() error { + r, w, err := os.Pipe() + if err != nil { + return fmt.Errorf("proxy unable to open os.Pipe %s", err) + } + defer r.Close() + p.cmd.ExtraFiles = []*os.File{w} + if err := p.cmd.Start(); err != nil { + return err + } + w.Close() + + errchan := make(chan error, 1) + go func() { + buf := make([]byte, 2) + r.Read(buf) + + if string(buf) != "0\n" { + errStr, err := ioutil.ReadAll(r) + if err != nil { + errchan <- fmt.Errorf("Error reading exit status from userland proxy: %v", err) + return + } + + errchan <- fmt.Errorf("Error starting userland proxy: %s", errStr) + return + } + errchan <- nil + }() + + select { + case err := <-errchan: + return err + case <-time.After(16 * time.Second): + return fmt.Errorf("Timed out proxy starting the userland proxy") + } +} + +func (p *proxyCommand) Stop() error { + if p.cmd.Process != nil { + if err := p.cmd.Process.Signal(os.Interrupt); err != nil { + return err + } + return p.cmd.Wait() + } + return nil +}