package portallocator import ( "errors" "net" "sync" ) type ( portMap map[int]bool protocolMap map[string]portMap ipMapping map[string]protocolMap ) const ( BeginPortRange = 49153 EndPortRange = 65535 ) var ( ErrAllPortsAllocated = errors.New("all ports are allocated") ErrPortAlreadyAllocated = errors.New("port has already been allocated") ErrUnknownProtocol = errors.New("unknown protocol") ) var ( mutex sync.Mutex defaultIP = net.ParseIP("0.0.0.0") globalMap = ipMapping{} ) func RequestPort(ip net.IP, proto string, port int) (int, error) { mutex.Lock() defer mutex.Unlock() if err := validateProto(proto); err != nil { return 0, err } ip = getDefault(ip) mapping := getOrCreate(ip) if port > 0 { if !mapping[proto][port] { mapping[proto][port] = true return port, nil } else { return 0, ErrPortAlreadyAllocated } } else { port, err := findPort(ip, proto) if err != nil { return 0, err } return port, nil } } func ReleasePort(ip net.IP, proto string, port int) error { mutex.Lock() defer mutex.Unlock() ip = getDefault(ip) mapping := getOrCreate(ip) delete(mapping[proto], port) return nil } func ReleaseAll() error { mutex.Lock() defer mutex.Unlock() globalMap = ipMapping{} return nil } func getOrCreate(ip net.IP) protocolMap { ipstr := ip.String() if _, ok := globalMap[ipstr]; !ok { globalMap[ipstr] = protocolMap{ "tcp": portMap{}, "udp": portMap{}, } } return globalMap[ipstr] } func findPort(ip net.IP, proto string) (int, error) { port := BeginPortRange mapping := getOrCreate(ip) for mapping[proto][port] { port++ if port > EndPortRange { return 0, ErrAllPortsAllocated } } mapping[proto][port] = true return port, nil } func getDefault(ip net.IP) net.IP { if ip == nil { return defaultIP } return ip } func validateProto(proto string) error { if proto != "tcp" && proto != "udp" { return ErrUnknownProtocol } return nil }