diff --git a/networkdriver/portallocator/allocator.go b/networkdriver/portallocator/allocator.go index 91036de199..5b6befd1da 100644 --- a/networkdriver/portallocator/allocator.go +++ b/networkdriver/portallocator/allocator.go @@ -8,7 +8,7 @@ import ( "sync" ) -type networkSet map[iPNet]iPSet +type networkSet map[iPNet]*iPSet type iPNet struct { IP string @@ -44,8 +44,8 @@ func RegisterNetwork(network *net.IPNet) error { } n := newIPNet(network) - allocatedIPs[n] = iPSet{} - availableIPS[n] = iPSet{} + allocatedIPs[n] = &iPSet{} + availableIPS[n] = &iPSet{} return nil } @@ -72,13 +72,18 @@ func ReleaseIP(network *net.IPNet, ip *net.IP) error { lock.Lock() defer lock.Unlock() - n := newIPNet(network) - existing := allocatedIPs[n] + var ( + first, _ = networkRange(network) + base = ipToInt(&first) + n = newIPNet(network) + existing = allocatedIPs[n] + available = availableIPS[n] + i = ipToInt(ip) + pos = i - base + ) - i := ipToInt(ip) - existing.Remove(int(i)) - available := availableIPS[n] - available.Push(int(i)) + existing.Remove(int(pos)) + available.Push(int(pos)) return nil } @@ -86,29 +91,43 @@ func ReleaseIP(network *net.IPNet, ip *net.IP) error { func getNextIp(network *net.IPNet) (*net.IP, error) { var ( n = newIPNet(network) + ownIP = ipToInt(&network.IP) available = availableIPS[n] - next = available.Pop() allocated = allocatedIPs[n] - ownIP = int(ipToInt(&network.IP)) + + first, _ = networkRange(network) + base = ipToInt(&first) + + pos = int32(available.Pop()) ) - if next != 0 { - ip := intToIP(int32(next)) - allocated.Push(int(next)) + // We pop and push the position not the ip + if pos != 0 { + ip := intToIP(int32(base + pos)) + allocated.Push(int(pos)) + return ip, nil } - size := int(networkSize(network.Mask)) - next = allocated.PullBack() + 1 - // size -1 for the broadcast address, -1 for the gateway address - for i := 0; i < size-2; i++ { + var ( + size = int(networkSize(network.Mask)) + max = int32(size - 2) // size -1 for the broadcast address, -1 for the gateway address + ) + + if pos = int32(allocated.PullBack()); pos == 0 { + pos = 1 + } + + for i := int32(0); i < max; i++ { + next := int32(base + pos) + pos = pos%max + 1 + if next == ownIP { - next++ continue } - ip := intToIP(int32(next)) - allocated.Push(next) + ip := intToIP(next) + allocated.Push(int(pos)) return ip, nil } @@ -117,6 +136,7 @@ func getNextIp(network *net.IPNet) (*net.IP, error) { func registerIP(network *net.IPNet, ip *net.IP) error { existing := allocatedIPs[newIPNet(network)] + // checking position not ip if existing.Exists(int(ipToInt(ip))) { return ErrIPAlreadyAllocated } diff --git a/networkdriver/portallocator/allocator_test.go b/networkdriver/portallocator/allocator_test.go index 570f415780..bcdcfa66b8 100644 --- a/networkdriver/portallocator/allocator_test.go +++ b/networkdriver/portallocator/allocator_test.go @@ -1,11 +1,18 @@ package ipallocator import ( + "fmt" "net" "testing" ) +func reset() { + allocatedIPs = networkSet{} + availableIPS = networkSet{} +} + func TestRegisterNetwork(t *testing.T) { + defer reset() network := &net.IPNet{ IP: []byte{192, 168, 0, 1}, Mask: []byte{255, 255, 255, 0}, @@ -24,3 +31,133 @@ func TestRegisterNetwork(t *testing.T) { t.Fatal("IPNet should exist in available IPs") } } + +func TestRegisterTwoNetworks(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + network2 := &net.IPNet{ + IP: []byte{10, 1, 42, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network2); err != nil { + t.Fatal(err) + } +} + +func TestRegisterNetworkThatExists(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + if err := RegisterNetwork(network); err != ErrNetworkAlreadyRegisterd { + t.Fatalf("Expected error of %s got %s", ErrNetworkAlreadyRegisterd, err) + } +} + +func TestRequestNewIps(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + for i := 2; i < 10; i++ { + ip, err := RequestIP(network, nil) + if err != nil { + t.Fatal(err) + } + + if expected := fmt.Sprintf("192.168.0.%d", i); ip.String() != expected { + t.Fatalf("Expected ip %s got %s", expected, ip.String()) + } + } +} + +func TestReleaseIp(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + ip, err := RequestIP(network, nil) + if err != nil { + t.Fatal(err) + } + + if err := ReleaseIP(network, ip); err != nil { + t.Fatal(err) + } +} + +func TestGetReleasedIp(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + ip, err := RequestIP(network, nil) + if err != nil { + t.Fatal(err) + } + + value := ip.String() + if err := ReleaseIP(network, ip); err != nil { + t.Fatal(err) + } + + ip, err = RequestIP(network, nil) + if err != nil { + t.Fatal(err) + } + + if ip.String() != value { + t.Fatalf("Expected to receive same ip %s got %s", value, ip.String()) + } +} + +func TestRequesetSpecificIp(t *testing.T) { + defer reset() + network := &net.IPNet{ + IP: []byte{192, 168, 0, 1}, + Mask: []byte{255, 255, 255, 0}, + } + + if err := RegisterNetwork(network); err != nil { + t.Fatal(err) + } + + ip := net.ParseIP("192.168.1.5") + + if _, err := RequestIP(network, &ip); err != nil { + t.Fatal(err) + } +} diff --git a/networkdriver/portallocator/ipset.go b/networkdriver/portallocator/ipset.go index 42b545b2d7..43d54691d1 100644 --- a/networkdriver/portallocator/ipset.go +++ b/networkdriver/portallocator/ipset.go @@ -82,8 +82,3 @@ func (s *iPSet) Remove(elem int) { } } } - -// Len returns the length of the list. -func (s *iPSet) Len() int { - return len(s.set) -}