mirror of
				https://github.com/moby/moby.git
				synced 2022-11-09 12:21:53 -05:00 
			
		
		
		
	Add IPVS netlink support
This PR adds netlink support to manipulate ipvs configuration. Signed-off-by: Jana Radhakrishnan <mrjana@docker.com>
This commit is contained in:
		
							parent
							
								
									ac18cc4b8f
								
							
						
					
					
						commit
						4b549ce428
					
				
					 4 changed files with 791 additions and 0 deletions
				
			
		
							
								
								
									
										130
									
								
								libnetwork/ipvs/constants.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								libnetwork/ipvs/constants.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,130 @@
 | 
			
		|||
// +build linux
 | 
			
		||||
 | 
			
		||||
package ipvs
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	genlCtrlID = 0x10
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GENL control commands
 | 
			
		||||
const (
 | 
			
		||||
	genlCtrlCmdUnspec uint8 = iota
 | 
			
		||||
	genlCtrlCmdNewFamily
 | 
			
		||||
	genlCtrlCmdDelFamily
 | 
			
		||||
	genlCtrlCmdGetFamily
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// GENL family attributes
 | 
			
		||||
const (
 | 
			
		||||
	genlCtrlAttrUnspec int = iota
 | 
			
		||||
	genlCtrlAttrFamilyID
 | 
			
		||||
	genlCtrlAttrFamilyName
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// IPVS genl commands
 | 
			
		||||
const (
 | 
			
		||||
	ipvsCmdUnspec uint8 = iota
 | 
			
		||||
	ipvsCmdNewService
 | 
			
		||||
	ipvsCmdSetService
 | 
			
		||||
	ipvsCmdDelService
 | 
			
		||||
	ipvsCmdGetService
 | 
			
		||||
	ipvsCmdNewDest
 | 
			
		||||
	ipvsCmdSetDest
 | 
			
		||||
	ipvsCmdDelDest
 | 
			
		||||
	ipvsCmdGetDest
 | 
			
		||||
	ipvsCmdNewDaemon
 | 
			
		||||
	ipvsCmdDelDaemon
 | 
			
		||||
	ipvsCmdGetDaemon
 | 
			
		||||
	ipvsCmdSetConfig
 | 
			
		||||
	ipvsCmdGetConfig
 | 
			
		||||
	ipvsCmdSetInfo
 | 
			
		||||
	ipvsCmdGetInfo
 | 
			
		||||
	ipvsCmdZero
 | 
			
		||||
	ipvsCmdFlush
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Attributes used in the first level of commands
 | 
			
		||||
const (
 | 
			
		||||
	ipvsCmdAttrUnspec int = iota
 | 
			
		||||
	ipvsCmdAttrService
 | 
			
		||||
	ipvsCmdAttrDest
 | 
			
		||||
	ipvsCmdAttrDaemon
 | 
			
		||||
	ipvsCmdAttrTimeoutTCP
 | 
			
		||||
	ipvsCmdAttrTimeoutTCPFin
 | 
			
		||||
	ipvsCmdAttrTimeoutUDP
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Attributes used to describe a service. Used inside nested attribute
 | 
			
		||||
// ipvsCmdAttrService
 | 
			
		||||
const (
 | 
			
		||||
	ipvsSvcAttrUnspec int = iota
 | 
			
		||||
	ipvsSvcAttrAddressFamily
 | 
			
		||||
	ipvsSvcAttrProtocol
 | 
			
		||||
	ipvsSvcAttrAddress
 | 
			
		||||
	ipvsSvcAttrPort
 | 
			
		||||
	ipvsSvcAttrFWMark
 | 
			
		||||
	ipvsSvcAttrSchedName
 | 
			
		||||
	ipvsSvcAttrFlags
 | 
			
		||||
	ipvsSvcAttrTimeout
 | 
			
		||||
	ipvsSvcAttrNetmask
 | 
			
		||||
	ipvsSvcAttrStats
 | 
			
		||||
	ipvsSvcAttrPEName
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Attributes used to describe a destination (real server). Used
 | 
			
		||||
// inside nested attribute ipvsCmdAttrDest.
 | 
			
		||||
const (
 | 
			
		||||
	ipvsDestAttrUnspec int = iota
 | 
			
		||||
	ipvsDestAttrAddress
 | 
			
		||||
	ipvsDestAttrPort
 | 
			
		||||
	ipvsDestAttrForwardingMethod
 | 
			
		||||
	ipvsDestAttrWeight
 | 
			
		||||
	ipvsDestAttrUpperThreshold
 | 
			
		||||
	ipvsDestAttrLowerThreshold
 | 
			
		||||
	ipvsDestAttrActiveConnections
 | 
			
		||||
	ipvsDestAttrInactiveConnections
 | 
			
		||||
	ipvsDestAttrPersistentConnections
 | 
			
		||||
	ipvsDestAttrStats
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Destination forwarding methods
 | 
			
		||||
const (
 | 
			
		||||
	// ConnectionFlagFwdmask indicates the mask in the connection
 | 
			
		||||
	// flags which is used by forwarding method bits.
 | 
			
		||||
	ConnectionFlagFwdMask = 0x0007
 | 
			
		||||
 | 
			
		||||
	// ConnectionFlagMasq is used for masquerade forwarding method.
 | 
			
		||||
	ConnectionFlagMasq = 0x0000
 | 
			
		||||
 | 
			
		||||
	// ConnectionFlagLocalNode is used for local node forwarding
 | 
			
		||||
	// method.
 | 
			
		||||
	ConnectionFlagLocalNode = 0x0001
 | 
			
		||||
 | 
			
		||||
	// ConnectionFlagTunnel is used for tunnel mode forwarding
 | 
			
		||||
	// method.
 | 
			
		||||
	ConnectionFlagTunnel = 0x0002
 | 
			
		||||
 | 
			
		||||
	// ConnectionFlagDirectRoute is used for direct routing
 | 
			
		||||
	// forwarding method.
 | 
			
		||||
	ConnectionFlagDirectRoute = 0x0003
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// RoundRobin distributes jobs equally amongst the available
 | 
			
		||||
	// real servers.
 | 
			
		||||
	RoundRobin = "rr"
 | 
			
		||||
 | 
			
		||||
	// LeastConnection assigns more jobs to real servers with
 | 
			
		||||
	// fewer active jobs.
 | 
			
		||||
	LeastConnection = "lc"
 | 
			
		||||
 | 
			
		||||
	// DestinationHashing assigns jobs to servers through looking
 | 
			
		||||
	// up a statically assigned hash table by their destination IP
 | 
			
		||||
	// addresses.
 | 
			
		||||
	DestinationHashing = "dh"
 | 
			
		||||
 | 
			
		||||
	// SourceHashing assigns jobs to servers through looking up
 | 
			
		||||
	// a statically assigned hash table by their source IP
 | 
			
		||||
	// addresses.
 | 
			
		||||
	SourceHashing = "sh"
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										113
									
								
								libnetwork/ipvs/ipvs.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								libnetwork/ipvs/ipvs.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,113 @@
 | 
			
		|||
// +build linux
 | 
			
		||||
 | 
			
		||||
package ipvs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"syscall"
 | 
			
		||||
 | 
			
		||||
	"github.com/vishvananda/netlink/nl"
 | 
			
		||||
	"github.com/vishvananda/netns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Service defines an IPVS service in its entirety.
 | 
			
		||||
type Service struct {
 | 
			
		||||
	// Virtual service address.
 | 
			
		||||
	Address  net.IP
 | 
			
		||||
	Protocol uint16
 | 
			
		||||
	Port     uint16
 | 
			
		||||
	FWMark   uint32 // Firewall mark of the service.
 | 
			
		||||
 | 
			
		||||
	// Virtual service options.
 | 
			
		||||
	SchedName     string
 | 
			
		||||
	Flags         uint32
 | 
			
		||||
	Timeout       uint32
 | 
			
		||||
	Netmask       uint32
 | 
			
		||||
	AddressFamily uint16
 | 
			
		||||
	PEName        string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Destination defines an IPVS destination (real server) in its
 | 
			
		||||
// entirety.
 | 
			
		||||
type Destination struct {
 | 
			
		||||
	Address         net.IP
 | 
			
		||||
	Port            uint16
 | 
			
		||||
	Weight          int
 | 
			
		||||
	ConnectionFlags uint32
 | 
			
		||||
	AddressFamily   uint16
 | 
			
		||||
	UpperThreshold  uint32
 | 
			
		||||
	LowerThreshold  uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handle provides a namespace specific ipvs handle to program ipvs
 | 
			
		||||
// rules.
 | 
			
		||||
type Handle struct {
 | 
			
		||||
	sock *nl.NetlinkSocket
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New provides a new ipvs handle in the namespace pointed to by the
 | 
			
		||||
// passed path. It will return a valid handle or an error in case an
 | 
			
		||||
// error occured while creating the handle.
 | 
			
		||||
func New(path string) (*Handle, error) {
 | 
			
		||||
	setup()
 | 
			
		||||
 | 
			
		||||
	n := netns.None()
 | 
			
		||||
	if path != "" {
 | 
			
		||||
		var err error
 | 
			
		||||
		n, err = netns.GetFromPath(path)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sock, err := nl.GetNetlinkSocketAt(n, netns.None(), syscall.NETLINK_GENERIC)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		n.Close()
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Handle{sock: sock}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close closes the ipvs handle. The handle is invalid after Close
 | 
			
		||||
// returns.
 | 
			
		||||
func (i *Handle) Close() {
 | 
			
		||||
	if i.sock != nil {
 | 
			
		||||
		i.sock.Close()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewService creates a new ipvs service in the passed handle.
 | 
			
		||||
func (i *Handle) NewService(s *Service) error {
 | 
			
		||||
	return i.doCmd(s, nil, ipvsCmdNewService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateService updates an already existing service in the passed
 | 
			
		||||
// handle.
 | 
			
		||||
func (i *Handle) UpdateService(s *Service) error {
 | 
			
		||||
	return i.doCmd(s, nil, ipvsCmdSetService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DelService deletes an already existing service in the passed
 | 
			
		||||
// handle.
 | 
			
		||||
func (i *Handle) DelService(s *Service) error {
 | 
			
		||||
	return i.doCmd(s, nil, ipvsCmdDelService)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDestination creates an new real server in the passed ipvs
 | 
			
		||||
// service which should already be existing in the passed handle.
 | 
			
		||||
func (i *Handle) NewDestination(s *Service, d *Destination) error {
 | 
			
		||||
	return i.doCmd(s, d, ipvsCmdNewDest)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateDestination updates an already existing real server in the
 | 
			
		||||
// passed ipvs service in the passed handle.
 | 
			
		||||
func (i *Handle) UpdateDestination(s *Service, d *Destination) error {
 | 
			
		||||
	return i.doCmd(s, d, ipvsCmdSetDest)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DelDestination deletes an already existing real server in the
 | 
			
		||||
// passed ipvs service in the passed handle.
 | 
			
		||||
func (i *Handle) DelDestination(s *Service, d *Destination) error {
 | 
			
		||||
	return i.doCmd(s, d, ipvsCmdDelDest)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										321
									
								
								libnetwork/ipvs/ipvs_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										321
									
								
								libnetwork/ipvs/ipvs_test.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,321 @@
 | 
			
		|||
// +build linux
 | 
			
		||||
 | 
			
		||||
package ipvs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os/exec"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/docker/libnetwork/testutils"
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"github.com/vishvananda/netlink"
 | 
			
		||||
	"github.com/vishvananda/netlink/nl"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	schedMethods = []string{
 | 
			
		||||
		RoundRobin,
 | 
			
		||||
		LeastConnection,
 | 
			
		||||
		DestinationHashing,
 | 
			
		||||
		SourceHashing,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	protocols = []string{
 | 
			
		||||
		"TCP",
 | 
			
		||||
		"UDP",
 | 
			
		||||
		"FWM",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fwdMethods = []uint32{
 | 
			
		||||
		ConnectionFlagMasq,
 | 
			
		||||
		ConnectionFlagTunnel,
 | 
			
		||||
		ConnectionFlagDirectRoute,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fwdMethodStrings = []string{
 | 
			
		||||
		"Masq",
 | 
			
		||||
		"Tunnel",
 | 
			
		||||
		"Route",
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func checkDestination(t *testing.T, checkPresent bool, protocol, serviceAddress, realAddress, fwdMethod string) {
 | 
			
		||||
	var (
 | 
			
		||||
		realServerStart bool
 | 
			
		||||
		realServers     []string
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput()
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	for _, o := range strings.Split(string(out), "\n") {
 | 
			
		||||
		cmpStr := serviceAddress
 | 
			
		||||
		if protocol == "FWM" {
 | 
			
		||||
			cmpStr = " " + cmpStr
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(o, cmpStr) {
 | 
			
		||||
			realServerStart = true
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if realServerStart {
 | 
			
		||||
			if !strings.Contains(o, "->") {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			realServers = append(realServers, o)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, r := range realServers {
 | 
			
		||||
		if strings.Contains(r, realAddress) {
 | 
			
		||||
			parts := strings.Fields(r)
 | 
			
		||||
			assert.Equal(t, fwdMethod, parts[2])
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if checkPresent {
 | 
			
		||||
		t.Fatalf("Did not find the destination %s fwdMethod %s in ipvs output", realAddress, fwdMethod)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func checkService(t *testing.T, checkPresent bool, protocol, schedMethod, serviceAddress string) {
 | 
			
		||||
	out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput()
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	for _, o := range strings.Split(string(out), "\n") {
 | 
			
		||||
		cmpStr := serviceAddress
 | 
			
		||||
		if protocol == "FWM" {
 | 
			
		||||
			cmpStr = " " + cmpStr
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(o, cmpStr) {
 | 
			
		||||
			parts := strings.Split(o, " ")
 | 
			
		||||
			assert.Equal(t, protocol, parts[0])
 | 
			
		||||
			assert.Equal(t, serviceAddress, parts[2])
 | 
			
		||||
			assert.Equal(t, schedMethod, parts[3])
 | 
			
		||||
 | 
			
		||||
			if !checkPresent {
 | 
			
		||||
				t.Fatalf("Did not expect the service %s in ipvs output", serviceAddress)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if checkPresent {
 | 
			
		||||
		t.Fatalf("Did not find the service %s in ipvs output", serviceAddress)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetFamily(t *testing.T) {
 | 
			
		||||
	if testutils.RunningOnCircleCI() {
 | 
			
		||||
		t.Skipf("Skipping as not supported on CIRCLE CI kernel")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	id, err := getIPVSFamily()
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	assert.NotEqual(t, 0, id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestService(t *testing.T) {
 | 
			
		||||
	if testutils.RunningOnCircleCI() {
 | 
			
		||||
		t.Skipf("Skipping as not supported on CIRCLE CI kernel")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer testutils.SetupTestOSContext(t)()
 | 
			
		||||
 | 
			
		||||
	i, err := New("")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	for _, protocol := range protocols {
 | 
			
		||||
		for _, schedMethod := range schedMethods {
 | 
			
		||||
			var serviceAddress string
 | 
			
		||||
 | 
			
		||||
			s := Service{
 | 
			
		||||
				AddressFamily: nl.FAMILY_V4,
 | 
			
		||||
				SchedName:     schedMethod,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch protocol {
 | 
			
		||||
			case "FWM":
 | 
			
		||||
				s.FWMark = 1234
 | 
			
		||||
				serviceAddress = fmt.Sprintf("%d", 1234)
 | 
			
		||||
			case "TCP":
 | 
			
		||||
				s.Protocol = syscall.IPPROTO_TCP
 | 
			
		||||
				s.Port = 80
 | 
			
		||||
				s.Address = net.ParseIP("1.2.3.4")
 | 
			
		||||
				s.Netmask = 0xFFFFFFFF
 | 
			
		||||
				serviceAddress = "1.2.3.4:80"
 | 
			
		||||
			case "UDP":
 | 
			
		||||
				s.Protocol = syscall.IPPROTO_UDP
 | 
			
		||||
				s.Port = 53
 | 
			
		||||
				s.Address = net.ParseIP("2.3.4.5")
 | 
			
		||||
				serviceAddress = "2.3.4.5:53"
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err := i.NewService(&s)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			checkService(t, true, protocol, schedMethod, serviceAddress)
 | 
			
		||||
			var lastMethod string
 | 
			
		||||
			for _, updateSchedMethod := range schedMethods {
 | 
			
		||||
				if updateSchedMethod == schedMethod {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				s.SchedName = updateSchedMethod
 | 
			
		||||
				err = i.UpdateService(&s)
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				checkService(t, true, protocol, updateSchedMethod, serviceAddress)
 | 
			
		||||
				lastMethod = updateSchedMethod
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = i.DelService(&s)
 | 
			
		||||
			checkService(t, false, protocol, lastMethod, serviceAddress)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createDummyInterface(t *testing.T) {
 | 
			
		||||
	if testutils.RunningOnCircleCI() {
 | 
			
		||||
		t.Skipf("Skipping as not supported on CIRCLE CI kernel")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dummy := &netlink.Dummy{
 | 
			
		||||
		LinkAttrs: netlink.LinkAttrs{
 | 
			
		||||
			Name: "dummy",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := netlink.LinkAdd(dummy)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	dummyLink, err := netlink.LinkByName("dummy")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	ip, ipNet, err := net.ParseCIDR("10.1.1.1/24")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	ipNet.IP = ip
 | 
			
		||||
 | 
			
		||||
	ipAddr := &netlink.Addr{IPNet: ipNet, Label: ""}
 | 
			
		||||
	err = netlink.AddrAdd(dummyLink, ipAddr)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDestination(t *testing.T) {
 | 
			
		||||
	defer testutils.SetupTestOSContext(t)()
 | 
			
		||||
 | 
			
		||||
	createDummyInterface(t)
 | 
			
		||||
	i, err := New("")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	for _, protocol := range []string{"TCP"} {
 | 
			
		||||
		var serviceAddress string
 | 
			
		||||
 | 
			
		||||
		s := Service{
 | 
			
		||||
			AddressFamily: nl.FAMILY_V4,
 | 
			
		||||
			SchedName:     RoundRobin,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		switch protocol {
 | 
			
		||||
		case "FWM":
 | 
			
		||||
			s.FWMark = 1234
 | 
			
		||||
			serviceAddress = fmt.Sprintf("%d", 1234)
 | 
			
		||||
		case "TCP":
 | 
			
		||||
			s.Protocol = syscall.IPPROTO_TCP
 | 
			
		||||
			s.Port = 80
 | 
			
		||||
			s.Address = net.ParseIP("1.2.3.4")
 | 
			
		||||
			s.Netmask = 0xFFFFFFFF
 | 
			
		||||
			serviceAddress = "1.2.3.4:80"
 | 
			
		||||
		case "UDP":
 | 
			
		||||
			s.Protocol = syscall.IPPROTO_UDP
 | 
			
		||||
			s.Port = 53
 | 
			
		||||
			s.Address = net.ParseIP("2.3.4.5")
 | 
			
		||||
			serviceAddress = "2.3.4.5:53"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err := i.NewService(&s)
 | 
			
		||||
		assert.NoError(t, err)
 | 
			
		||||
		checkService(t, true, protocol, RoundRobin, serviceAddress)
 | 
			
		||||
 | 
			
		||||
		s.SchedName = ""
 | 
			
		||||
		for j, fwdMethod := range fwdMethods {
 | 
			
		||||
			d1 := Destination{
 | 
			
		||||
				AddressFamily:   nl.FAMILY_V4,
 | 
			
		||||
				Address:         net.ParseIP("10.1.1.2"),
 | 
			
		||||
				Port:            5000,
 | 
			
		||||
				Weight:          1,
 | 
			
		||||
				ConnectionFlags: fwdMethod,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			realAddress := "10.1.1.2:5000"
 | 
			
		||||
			err := i.NewDestination(&s, &d1)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
 | 
			
		||||
			d2 := Destination{
 | 
			
		||||
				AddressFamily:   nl.FAMILY_V4,
 | 
			
		||||
				Address:         net.ParseIP("10.1.1.3"),
 | 
			
		||||
				Port:            5000,
 | 
			
		||||
				Weight:          1,
 | 
			
		||||
				ConnectionFlags: fwdMethod,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			realAddress = "10.1.1.3:5000"
 | 
			
		||||
			err = i.NewDestination(&s, &d2)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
 | 
			
		||||
 | 
			
		||||
			d3 := Destination{
 | 
			
		||||
				AddressFamily:   nl.FAMILY_V4,
 | 
			
		||||
				Address:         net.ParseIP("10.1.1.4"),
 | 
			
		||||
				Port:            5000,
 | 
			
		||||
				Weight:          1,
 | 
			
		||||
				ConnectionFlags: fwdMethod,
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			realAddress = "10.1.1.4:5000"
 | 
			
		||||
			err = i.NewDestination(&s, &d3)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
 | 
			
		||||
 | 
			
		||||
			for m, updateFwdMethod := range fwdMethods {
 | 
			
		||||
				if updateFwdMethod == fwdMethod {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				d1.ConnectionFlags = updateFwdMethod
 | 
			
		||||
				realAddress = "10.1.1.2:5000"
 | 
			
		||||
				err = i.UpdateDestination(&s, &d1)
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
 | 
			
		||||
 | 
			
		||||
				d2.ConnectionFlags = updateFwdMethod
 | 
			
		||||
				realAddress = "10.1.1.3:5000"
 | 
			
		||||
				err = i.UpdateDestination(&s, &d2)
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
 | 
			
		||||
 | 
			
		||||
				d3.ConnectionFlags = updateFwdMethod
 | 
			
		||||
				realAddress = "10.1.1.4:5000"
 | 
			
		||||
				err = i.UpdateDestination(&s, &d3)
 | 
			
		||||
				assert.NoError(t, err)
 | 
			
		||||
				checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			err = i.DelDestination(&s, &d1)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			err = i.DelDestination(&s, &d2)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
			err = i.DelDestination(&s, &d3)
 | 
			
		||||
			assert.NoError(t, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										227
									
								
								libnetwork/ipvs/netlink.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								libnetwork/ipvs/netlink.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,227 @@
 | 
			
		|||
// +build linux
 | 
			
		||||
 | 
			
		||||
package ipvs
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"unsafe"
 | 
			
		||||
 | 
			
		||||
	"github.com/vishvananda/netlink/nl"
 | 
			
		||||
	"github.com/vishvananda/netns"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	native     = nl.NativeEndian()
 | 
			
		||||
	ipvsFamily int
 | 
			
		||||
	ipvsOnce   sync.Once
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type genlMsgHdr struct {
 | 
			
		||||
	cmd      uint8
 | 
			
		||||
	version  uint8
 | 
			
		||||
	reserved uint16
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ipvsFlags struct {
 | 
			
		||||
	flags uint32
 | 
			
		||||
	mask  uint32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
 | 
			
		||||
	return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0]))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (hdr *genlMsgHdr) Serialize() []byte {
 | 
			
		||||
	return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (hdr *genlMsgHdr) Len() int {
 | 
			
		||||
	return int(unsafe.Sizeof(*hdr))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *ipvsFlags) Serialize() []byte {
 | 
			
		||||
	return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f *ipvsFlags) Len() int {
 | 
			
		||||
	return int(unsafe.Sizeof(*f))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setup() {
 | 
			
		||||
	ipvsOnce.Do(func() {
 | 
			
		||||
		var err error
 | 
			
		||||
		ipvsFamily, err = getIPVSFamily()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			panic("could not get ipvs family")
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fillService(s *Service) nl.NetlinkRequestData {
 | 
			
		||||
	cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil)
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))
 | 
			
		||||
	if s.FWMark != 0 {
 | 
			
		||||
		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
 | 
			
		||||
	} else {
 | 
			
		||||
		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
 | 
			
		||||
		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address))
 | 
			
		||||
 | 
			
		||||
		// Port needs to be in network byte order.
 | 
			
		||||
		portBuf := new(bytes.Buffer)
 | 
			
		||||
		binary.Write(portBuf, binary.BigEndian, s.Port)
 | 
			
		||||
		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
 | 
			
		||||
	if s.PEName != "" {
 | 
			
		||||
		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	f := &ipvsFlags{
 | 
			
		||||
		flags: s.Flags,
 | 
			
		||||
		mask:  0xFFFFFFFF,
 | 
			
		||||
	}
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize())
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
 | 
			
		||||
	return cmdAttr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func fillDestinaton(d *Destination) nl.NetlinkRequestData {
 | 
			
		||||
	cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil)
 | 
			
		||||
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address))
 | 
			
		||||
	// Port needs to be in network byte order.
 | 
			
		||||
	portBuf := new(bytes.Buffer)
 | 
			
		||||
	binary.Write(portBuf, binary.BigEndian, d.Port)
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes())
 | 
			
		||||
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask))
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
 | 
			
		||||
	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
 | 
			
		||||
 | 
			
		||||
	return cmdAttr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error {
 | 
			
		||||
	req := newIPVSRequest(cmd)
 | 
			
		||||
	req.AddData(fillService(s))
 | 
			
		||||
 | 
			
		||||
	if d != nil {
 | 
			
		||||
		req.AddData(fillDestinaton(d))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := execute(i.sock, req, 0); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getIPVSFamily() (int, error) {
 | 
			
		||||
	sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily)
 | 
			
		||||
	req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS")))
 | 
			
		||||
 | 
			
		||||
	msgs, err := execute(sock, req, 0)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, m := range msgs {
 | 
			
		||||
		hdr := deserializeGenlMsg(m)
 | 
			
		||||
		attrs, err := nl.ParseRouteAttr(m[hdr.Len():])
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, attr := range attrs {
 | 
			
		||||
			switch int(attr.Attr.Type) {
 | 
			
		||||
			case genlCtrlAttrFamilyID:
 | 
			
		||||
				return int(native.Uint16(attr.Value[0:2])), nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return 0, fmt.Errorf("no family id in the netlink response")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func rawIPData(ip net.IP) []byte {
 | 
			
		||||
	family := nl.GetIPFamily(ip)
 | 
			
		||||
	if family == nl.FAMILY_V4 {
 | 
			
		||||
		return ip.To4()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ip
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
 | 
			
		||||
	return newGenlRequest(ipvsFamily, cmd)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
 | 
			
		||||
	req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
 | 
			
		||||
	req.AddData(&genlMsgHdr{cmd: cmd, version: 1})
 | 
			
		||||
	return req
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		err error
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err := s.Send(req); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pid, err := s.GetPid()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res [][]byte
 | 
			
		||||
 | 
			
		||||
done:
 | 
			
		||||
	for {
 | 
			
		||||
		msgs, err := s.Receive()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		for _, m := range msgs {
 | 
			
		||||
			if m.Header.Seq != req.Seq {
 | 
			
		||||
				return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
 | 
			
		||||
			}
 | 
			
		||||
			if m.Header.Pid != pid {
 | 
			
		||||
				return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
 | 
			
		||||
			}
 | 
			
		||||
			if m.Header.Type == syscall.NLMSG_DONE {
 | 
			
		||||
				break done
 | 
			
		||||
			}
 | 
			
		||||
			if m.Header.Type == syscall.NLMSG_ERROR {
 | 
			
		||||
				error := int32(native.Uint32(m.Data[0:4]))
 | 
			
		||||
				if error == 0 {
 | 
			
		||||
					break done
 | 
			
		||||
				}
 | 
			
		||||
				return nil, syscall.Errno(-error)
 | 
			
		||||
			}
 | 
			
		||||
			if resType != 0 && m.Header.Type != resType {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			res = append(res, m.Data)
 | 
			
		||||
			if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
 | 
			
		||||
				break done
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue