diff --git a/network.go b/network.go index 1397de0557..fc50a5a907 100644 --- a/network.go +++ b/network.go @@ -225,16 +225,22 @@ func getIfaceAddr(name string) (net.Addr, error) { // up iptables rules. // It keeps track of all mappings and is able to unmap at will type PortMapper struct { - tcpMapping map[int]*net.TCPAddr - tcpProxies map[int]proxy.Proxy - udpMapping map[int]*net.UDPAddr - udpProxies map[int]proxy.Proxy + tcpMapping map[string]*net.TCPAddr + tcpProxies map[string]proxy.Proxy + udpMapping map[string]*net.UDPAddr + udpProxies map[string]proxy.Proxy - iptables *iptables.Chain - defaultIp net.IP + iptables *iptables.Chain + defaultIp net.IP + proxyFactoryFunc func(net.Addr, net.Addr) (proxy.Proxy, error) } func (mapper *PortMapper) Map(ip net.IP, port int, backendAddr net.Addr) error { + mapKey := (&net.TCPAddr{Port: port, IP: ip}).String() + if _, exists := mapper.tcpProxies[mapKey]; exists { + return fmt.Errorf("Port %s is already in use", mapKey) + } + if _, isTCP := backendAddr.(*net.TCPAddr); isTCP { backendPort := backendAddr.(*net.TCPAddr).Port backendIP := backendAddr.(*net.TCPAddr).IP @@ -243,13 +249,13 @@ func (mapper *PortMapper) Map(ip net.IP, port int, backendAddr net.Addr) error { return err } } - mapper.tcpMapping[port] = backendAddr.(*net.TCPAddr) - proxy, err := proxy.NewProxy(&net.TCPAddr{IP: ip, Port: port}, backendAddr) + mapper.tcpMapping[mapKey] = backendAddr.(*net.TCPAddr) + proxy, err := mapper.proxyFactoryFunc(&net.TCPAddr{IP: ip, Port: port}, backendAddr) if err != nil { mapper.Unmap(ip, port, "tcp") return err } - mapper.tcpProxies[port] = proxy + mapper.tcpProxies[mapKey] = proxy go proxy.Run() } else { backendPort := backendAddr.(*net.UDPAddr).Port @@ -259,49 +265,50 @@ func (mapper *PortMapper) Map(ip net.IP, port int, backendAddr net.Addr) error { return err } } - mapper.udpMapping[port] = backendAddr.(*net.UDPAddr) - proxy, err := proxy.NewProxy(&net.UDPAddr{IP: ip, Port: port}, backendAddr) + mapper.udpMapping[mapKey] = backendAddr.(*net.UDPAddr) + proxy, err := mapper.proxyFactoryFunc(&net.UDPAddr{IP: ip, Port: port}, backendAddr) if err != nil { mapper.Unmap(ip, port, "udp") return err } - mapper.udpProxies[port] = proxy + mapper.udpProxies[mapKey] = proxy go proxy.Run() } return nil } func (mapper *PortMapper) Unmap(ip net.IP, port int, proto string) error { + mapKey := (&net.TCPAddr{Port: port, IP: ip}).String() if proto == "tcp" { - backendAddr, ok := mapper.tcpMapping[port] + backendAddr, ok := mapper.tcpMapping[mapKey] if !ok { - return fmt.Errorf("Port tcp/%v is not mapped", port) + return fmt.Errorf("Port tcp/%s is not mapped", mapKey) } - if proxy, exists := mapper.tcpProxies[port]; exists { + if proxy, exists := mapper.tcpProxies[mapKey]; exists { proxy.Close() - delete(mapper.tcpProxies, port) + delete(mapper.tcpProxies, mapKey) } if mapper.iptables != nil { if err := mapper.iptables.Forward(iptables.Delete, ip, port, proto, backendAddr.IP.String(), backendAddr.Port); err != nil { return err } } - delete(mapper.tcpMapping, port) + delete(mapper.tcpMapping, mapKey) } else { - backendAddr, ok := mapper.udpMapping[port] + backendAddr, ok := mapper.udpMapping[mapKey] if !ok { - return fmt.Errorf("Port udp/%v is not mapped", port) + return fmt.Errorf("Port udp/%s is not mapped", mapKey) } - if proxy, exists := mapper.udpProxies[port]; exists { + if proxy, exists := mapper.udpProxies[mapKey]; exists { proxy.Close() - delete(mapper.udpProxies, port) + delete(mapper.udpProxies, mapKey) } if mapper.iptables != nil { if err := mapper.iptables.Forward(iptables.Delete, ip, port, proto, backendAddr.IP.String(), backendAddr.Port); err != nil { return err } } - delete(mapper.udpMapping, port) + delete(mapper.udpMapping, mapKey) } return nil } @@ -321,12 +328,13 @@ func newPortMapper(config *DaemonConfig) (*PortMapper, error) { } mapper := &PortMapper{ - tcpMapping: make(map[int]*net.TCPAddr), - tcpProxies: make(map[int]proxy.Proxy), - udpMapping: make(map[int]*net.UDPAddr), - udpProxies: make(map[int]proxy.Proxy), - iptables: chain, - defaultIp: config.DefaultIp, + tcpMapping: make(map[string]*net.TCPAddr), + tcpProxies: make(map[string]proxy.Proxy), + udpMapping: make(map[string]*net.UDPAddr), + udpProxies: make(map[string]proxy.Proxy), + iptables: chain, + defaultIp: config.DefaultIp, + proxyFactoryFunc: proxy.NewProxy, } return mapper, nil } @@ -334,7 +342,7 @@ func newPortMapper(config *DaemonConfig) (*PortMapper, error) { // Port allocator: Automatically allocate and release networking ports type PortAllocator struct { sync.Mutex - inUse map[int]struct{} + inUse map[string]struct{} fountain chan int quit chan bool } @@ -354,20 +362,22 @@ func (alloc *PortAllocator) runFountain() { } // FIXME: Release can no longer fail, change its prototype to reflect that. -func (alloc *PortAllocator) Release(port int) error { +func (alloc *PortAllocator) Release(addr net.IP, port int) error { + mapKey := (&net.TCPAddr{Port: port, IP: addr}).String() utils.Debugf("Releasing %d", port) alloc.Lock() - delete(alloc.inUse, port) + delete(alloc.inUse, mapKey) alloc.Unlock() return nil } -func (alloc *PortAllocator) Acquire(port int) (int, error) { - utils.Debugf("Acquiring %d", port) +func (alloc *PortAllocator) Acquire(addr net.IP, port int) (int, error) { + mapKey := (&net.TCPAddr{Port: port, IP: addr}).String() + utils.Debugf("Acquiring %s", mapKey) if port == 0 { // Allocate a port from the fountain for port := range alloc.fountain { - if _, err := alloc.Acquire(port); err == nil { + if _, err := alloc.Acquire(addr, port); err == nil { return port, nil } } @@ -375,10 +385,10 @@ func (alloc *PortAllocator) Acquire(port int) (int, error) { } alloc.Lock() defer alloc.Unlock() - if _, inUse := alloc.inUse[port]; inUse { + if _, inUse := alloc.inUse[mapKey]; inUse { return -1, fmt.Errorf("Port already in use: %d", port) } - alloc.inUse[port] = struct{}{} + alloc.inUse[mapKey] = struct{}{} return port, nil } @@ -391,7 +401,7 @@ func (alloc *PortAllocator) Close() error { func newPortAllocator() (*PortAllocator, error) { allocator := &PortAllocator{ - inUse: make(map[int]struct{}), + inUse: make(map[string]struct{}), fountain: make(chan int), quit: make(chan bool), } @@ -546,25 +556,25 @@ func (iface *NetworkInterface) AllocatePort(port Port, binding PortBinding) (*Na hostPort, _ := parsePort(nat.Binding.HostPort) if nat.Port.Proto() == "tcp" { - extPort, err := iface.manager.tcpPortAllocator.Acquire(hostPort) + extPort, err := iface.manager.tcpPortAllocator.Acquire(ip, hostPort) if err != nil { return nil, err } backend := &net.TCPAddr{IP: iface.IPNet.IP, Port: containerPort} if err := iface.manager.portMapper.Map(ip, extPort, backend); err != nil { - iface.manager.tcpPortAllocator.Release(extPort) + iface.manager.tcpPortAllocator.Release(ip, extPort) return nil, err } nat.Binding.HostPort = strconv.Itoa(extPort) } else { - extPort, err := iface.manager.udpPortAllocator.Acquire(hostPort) + extPort, err := iface.manager.udpPortAllocator.Acquire(ip, hostPort) if err != nil { return nil, err } backend := &net.UDPAddr{IP: iface.IPNet.IP, Port: containerPort} if err := iface.manager.portMapper.Map(ip, extPort, backend); err != nil { - iface.manager.udpPortAllocator.Release(extPort) + iface.manager.udpPortAllocator.Release(ip, extPort) return nil, err } nat.Binding.HostPort = strconv.Itoa(extPort) @@ -596,16 +606,19 @@ func (iface *NetworkInterface) Release() { continue } ip := net.ParseIP(nat.Binding.HostIp) - utils.Debugf("Unmaping %s/%s", nat.Port.Proto, nat.Binding.HostPort) + utils.Debugf("Unmaping %s/%s:%s", nat.Port.Proto, ip.String(), nat.Binding.HostPort) if err := iface.manager.portMapper.Unmap(ip, hostPort, nat.Port.Proto()); err != nil { log.Printf("Unable to unmap port %s: %s", nat, err) } + if nat.Port.Proto() == "tcp" { - if err := iface.manager.tcpPortAllocator.Release(hostPort); err != nil { + if err := iface.manager.tcpPortAllocator.Release(ip, hostPort); err != nil { log.Printf("Unable to release port %s", nat) } - } else if err := iface.manager.udpPortAllocator.Release(hostPort); err != nil { - log.Printf("Unable to release port %s: %s", nat, err) + } else if nat.Port.Proto() == "udp" { + if err := iface.manager.tcpPortAllocator.Release(ip, hostPort); err != nil { + log.Printf("Unable to release port %s: %s", nat, err) + } } } @@ -732,6 +745,7 @@ func newNetworkManager(config *DaemonConfig) (*NetworkManager, error) { if err != nil { return nil, err } + udpPortAllocator, err := newPortAllocator() if err != nil { return nil, err diff --git a/network_test.go b/network_test.go index e2631ddcb7..184b497938 100644 --- a/network_test.go +++ b/network_test.go @@ -1,42 +1,52 @@ package docker import ( + "github.com/dotcloud/docker/iptables" + "github.com/dotcloud/docker/proxy" "net" "testing" ) func TestPortAllocation(t *testing.T) { + ip := net.ParseIP("192.168.0.1") + ip2 := net.ParseIP("192.168.0.2") allocator, err := newPortAllocator() if err != nil { t.Fatal(err) } - if port, err := allocator.Acquire(80); err != nil { + if port, err := allocator.Acquire(ip, 80); err != nil { t.Fatal(err) } else if port != 80 { t.Fatalf("Acquire(80) should return 80, not %d", port) } - port, err := allocator.Acquire(0) + port, err := allocator.Acquire(ip, 0) if err != nil { t.Fatal(err) } if port <= 0 { t.Fatalf("Acquire(0) should return a non-zero port") } - if _, err := allocator.Acquire(port); err == nil { + if _, err := allocator.Acquire(ip, port); err == nil { t.Fatalf("Acquiring a port already in use should return an error") } - if newPort, err := allocator.Acquire(0); err != nil { + if newPort, err := allocator.Acquire(ip, 0); err != nil { t.Fatal(err) } else if newPort == port { t.Fatalf("Acquire(0) allocated the same port twice: %d", port) } - if _, err := allocator.Acquire(80); err == nil { + if _, err := allocator.Acquire(ip, 80); err == nil { t.Fatalf("Acquiring a port already in use should return an error") } - if err := allocator.Release(80); err != nil { + if _, err := allocator.Acquire(ip2, 80); err != nil { + t.Fatalf("It should be possible to allocate the same port on a different interface") + } + if _, err := allocator.Acquire(ip2, 80); err == nil { + t.Fatalf("Acquiring a port already in use should return an error") + } + if err := allocator.Release(ip, 80); err != nil { t.Fatal(err) } - if _, err := allocator.Acquire(80); err != nil { + if _, err := allocator.Acquire(ip, 80); err != nil { t.Fatal(err) } } @@ -311,3 +321,66 @@ func TestCheckNameserverOverlaps(t *testing.T) { t.Fatalf("%s should not overlap %v but it does", netX, nameservers) } } + +type StubProxy struct { + frontendAddr *net.Addr + backendAddr *net.Addr +} + +func (proxy *StubProxy) Run() {} +func (proxy *StubProxy) Close() {} +func (proxy *StubProxy) FrontendAddr() net.Addr { return *proxy.frontendAddr } +func (proxy *StubProxy) BackendAddr() net.Addr { return *proxy.backendAddr } + +func NewStubProxy(frontendAddr, backendAddr net.Addr) (proxy.Proxy, error) { + return &StubProxy{ + frontendAddr: &frontendAddr, + backendAddr: &backendAddr, + }, nil +} + +func TestPortMapper(t *testing.T) { + var chain *iptables.Chain + mapper := &PortMapper{ + tcpMapping: make(map[string]*net.TCPAddr), + tcpProxies: make(map[string]proxy.Proxy), + udpMapping: make(map[string]*net.UDPAddr), + udpProxies: make(map[string]proxy.Proxy), + iptables: chain, + defaultIp: net.IP("0.0.0.0"), + proxyFactoryFunc: NewStubProxy, + } + + dstIp1 := net.ParseIP("192.168.0.1") + dstIp2 := net.ParseIP("192.168.0.2") + srcAddr1 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.1")} + srcAddr2 := &net.TCPAddr{Port: 1080, IP: net.ParseIP("172.16.0.2")} + + if err := mapper.Map(dstIp1, 80, srcAddr1); err != nil { + t.Fatalf("Failed to allocate port: %s", err) + } + + if mapper.Map(dstIp1, 80, srcAddr1) == nil { + t.Fatalf("Port is in use - mapping should have failed") + } + + if mapper.Map(dstIp1, 80, srcAddr2) == nil { + t.Fatalf("Port is in use - mapping should have failed") + } + + if err := mapper.Map(dstIp2, 80, srcAddr2); err != nil { + t.Fatalf("Failed to allocate port: %s", err) + } + + if mapper.Unmap(dstIp1, 80, "tcp") != nil { + t.Fatalf("Failed to release port") + } + + if mapper.Unmap(dstIp2, 80, "tcp") != nil { + t.Fatalf("Failed to release port") + } + + if mapper.Unmap(dstIp2, 80, "tcp") == nil { + t.Fatalf("Port already released, but no error reported") + } +}