From 4b549ce42854b17219cf42fd06f813e8cf5aa2db Mon Sep 17 00:00:00 2001 From: Jana Radhakrishnan Date: Tue, 24 May 2016 13:12:16 -0700 Subject: [PATCH] Add IPVS netlink support This PR adds netlink support to manipulate ipvs configuration. Signed-off-by: Jana Radhakrishnan --- libnetwork/ipvs/constants.go | 130 ++++++++++++++ libnetwork/ipvs/ipvs.go | 113 ++++++++++++ libnetwork/ipvs/ipvs_test.go | 321 +++++++++++++++++++++++++++++++++++ libnetwork/ipvs/netlink.go | 227 +++++++++++++++++++++++++ 4 files changed, 791 insertions(+) create mode 100644 libnetwork/ipvs/constants.go create mode 100644 libnetwork/ipvs/ipvs.go create mode 100644 libnetwork/ipvs/ipvs_test.go create mode 100644 libnetwork/ipvs/netlink.go diff --git a/libnetwork/ipvs/constants.go b/libnetwork/ipvs/constants.go new file mode 100644 index 0000000000..103e71a37c --- /dev/null +++ b/libnetwork/ipvs/constants.go @@ -0,0 +1,130 @@ +// +build linux + +package ipvs + +const ( + genlCtrlID = 0x10 +) + +// GENL control commands +const ( + genlCtrlCmdUnspec uint8 = iota + genlCtrlCmdNewFamily + genlCtrlCmdDelFamily + genlCtrlCmdGetFamily +) + +// GENL family attributes +const ( + genlCtrlAttrUnspec int = iota + genlCtrlAttrFamilyID + genlCtrlAttrFamilyName +) + +// IPVS genl commands +const ( + ipvsCmdUnspec uint8 = iota + ipvsCmdNewService + ipvsCmdSetService + ipvsCmdDelService + ipvsCmdGetService + ipvsCmdNewDest + ipvsCmdSetDest + ipvsCmdDelDest + ipvsCmdGetDest + ipvsCmdNewDaemon + ipvsCmdDelDaemon + ipvsCmdGetDaemon + ipvsCmdSetConfig + ipvsCmdGetConfig + ipvsCmdSetInfo + ipvsCmdGetInfo + ipvsCmdZero + ipvsCmdFlush +) + +// Attributes used in the first level of commands +const ( + ipvsCmdAttrUnspec int = iota + ipvsCmdAttrService + ipvsCmdAttrDest + ipvsCmdAttrDaemon + ipvsCmdAttrTimeoutTCP + ipvsCmdAttrTimeoutTCPFin + ipvsCmdAttrTimeoutUDP +) + +// Attributes used to describe a service. Used inside nested attribute +// ipvsCmdAttrService +const ( + ipvsSvcAttrUnspec int = iota + ipvsSvcAttrAddressFamily + ipvsSvcAttrProtocol + ipvsSvcAttrAddress + ipvsSvcAttrPort + ipvsSvcAttrFWMark + ipvsSvcAttrSchedName + ipvsSvcAttrFlags + ipvsSvcAttrTimeout + ipvsSvcAttrNetmask + ipvsSvcAttrStats + ipvsSvcAttrPEName +) + +// Attributes used to describe a destination (real server). Used +// inside nested attribute ipvsCmdAttrDest. +const ( + ipvsDestAttrUnspec int = iota + ipvsDestAttrAddress + ipvsDestAttrPort + ipvsDestAttrForwardingMethod + ipvsDestAttrWeight + ipvsDestAttrUpperThreshold + ipvsDestAttrLowerThreshold + ipvsDestAttrActiveConnections + ipvsDestAttrInactiveConnections + ipvsDestAttrPersistentConnections + ipvsDestAttrStats +) + +// Destination forwarding methods +const ( + // ConnectionFlagFwdmask indicates the mask in the connection + // flags which is used by forwarding method bits. + ConnectionFlagFwdMask = 0x0007 + + // ConnectionFlagMasq is used for masquerade forwarding method. + ConnectionFlagMasq = 0x0000 + + // ConnectionFlagLocalNode is used for local node forwarding + // method. + ConnectionFlagLocalNode = 0x0001 + + // ConnectionFlagTunnel is used for tunnel mode forwarding + // method. + ConnectionFlagTunnel = 0x0002 + + // ConnectionFlagDirectRoute is used for direct routing + // forwarding method. + ConnectionFlagDirectRoute = 0x0003 +) + +const ( + // RoundRobin distributes jobs equally amongst the available + // real servers. + RoundRobin = "rr" + + // LeastConnection assigns more jobs to real servers with + // fewer active jobs. + LeastConnection = "lc" + + // DestinationHashing assigns jobs to servers through looking + // up a statically assigned hash table by their destination IP + // addresses. + DestinationHashing = "dh" + + // SourceHashing assigns jobs to servers through looking up + // a statically assigned hash table by their source IP + // addresses. + SourceHashing = "sh" +) diff --git a/libnetwork/ipvs/ipvs.go b/libnetwork/ipvs/ipvs.go new file mode 100644 index 0000000000..8f0a0ab89a --- /dev/null +++ b/libnetwork/ipvs/ipvs.go @@ -0,0 +1,113 @@ +// +build linux + +package ipvs + +import ( + "net" + "syscall" + + "github.com/vishvananda/netlink/nl" + "github.com/vishvananda/netns" +) + +// Service defines an IPVS service in its entirety. +type Service struct { + // Virtual service address. + Address net.IP + Protocol uint16 + Port uint16 + FWMark uint32 // Firewall mark of the service. + + // Virtual service options. + SchedName string + Flags uint32 + Timeout uint32 + Netmask uint32 + AddressFamily uint16 + PEName string +} + +// Destination defines an IPVS destination (real server) in its +// entirety. +type Destination struct { + Address net.IP + Port uint16 + Weight int + ConnectionFlags uint32 + AddressFamily uint16 + UpperThreshold uint32 + LowerThreshold uint32 +} + +// Handle provides a namespace specific ipvs handle to program ipvs +// rules. +type Handle struct { + sock *nl.NetlinkSocket +} + +// New provides a new ipvs handle in the namespace pointed to by the +// passed path. It will return a valid handle or an error in case an +// error occured while creating the handle. +func New(path string) (*Handle, error) { + setup() + + n := netns.None() + if path != "" { + var err error + n, err = netns.GetFromPath(path) + if err != nil { + return nil, err + } + } + + sock, err := nl.GetNetlinkSocketAt(n, netns.None(), syscall.NETLINK_GENERIC) + if err != nil { + n.Close() + return nil, err + } + + return &Handle{sock: sock}, nil +} + +// Close closes the ipvs handle. The handle is invalid after Close +// returns. +func (i *Handle) Close() { + if i.sock != nil { + i.sock.Close() + } +} + +// NewService creates a new ipvs service in the passed handle. +func (i *Handle) NewService(s *Service) error { + return i.doCmd(s, nil, ipvsCmdNewService) +} + +// UpdateService updates an already existing service in the passed +// handle. +func (i *Handle) UpdateService(s *Service) error { + return i.doCmd(s, nil, ipvsCmdSetService) +} + +// DelService deletes an already existing service in the passed +// handle. +func (i *Handle) DelService(s *Service) error { + return i.doCmd(s, nil, ipvsCmdDelService) +} + +// NewDestination creates an new real server in the passed ipvs +// service which should already be existing in the passed handle. +func (i *Handle) NewDestination(s *Service, d *Destination) error { + return i.doCmd(s, d, ipvsCmdNewDest) +} + +// UpdateDestination updates an already existing real server in the +// passed ipvs service in the passed handle. +func (i *Handle) UpdateDestination(s *Service, d *Destination) error { + return i.doCmd(s, d, ipvsCmdSetDest) +} + +// DelDestination deletes an already existing real server in the +// passed ipvs service in the passed handle. +func (i *Handle) DelDestination(s *Service, d *Destination) error { + return i.doCmd(s, d, ipvsCmdDelDest) +} diff --git a/libnetwork/ipvs/ipvs_test.go b/libnetwork/ipvs/ipvs_test.go new file mode 100644 index 0000000000..e0d110dd14 --- /dev/null +++ b/libnetwork/ipvs/ipvs_test.go @@ -0,0 +1,321 @@ +// +build linux + +package ipvs + +import ( + "fmt" + "net" + "os/exec" + "strings" + "syscall" + "testing" + + "github.com/docker/libnetwork/testutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" +) + +var ( + schedMethods = []string{ + RoundRobin, + LeastConnection, + DestinationHashing, + SourceHashing, + } + + protocols = []string{ + "TCP", + "UDP", + "FWM", + } + + fwdMethods = []uint32{ + ConnectionFlagMasq, + ConnectionFlagTunnel, + ConnectionFlagDirectRoute, + } + + fwdMethodStrings = []string{ + "Masq", + "Tunnel", + "Route", + } +) + +func checkDestination(t *testing.T, checkPresent bool, protocol, serviceAddress, realAddress, fwdMethod string) { + var ( + realServerStart bool + realServers []string + ) + + out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput() + require.NoError(t, err) + + for _, o := range strings.Split(string(out), "\n") { + cmpStr := serviceAddress + if protocol == "FWM" { + cmpStr = " " + cmpStr + } + + if strings.Contains(o, cmpStr) { + realServerStart = true + continue + } + + if realServerStart { + if !strings.Contains(o, "->") { + break + } + + realServers = append(realServers, o) + } + } + + for _, r := range realServers { + if strings.Contains(r, realAddress) { + parts := strings.Fields(r) + assert.Equal(t, fwdMethod, parts[2]) + return + } + } + + if checkPresent { + t.Fatalf("Did not find the destination %s fwdMethod %s in ipvs output", realAddress, fwdMethod) + } +} + +func checkService(t *testing.T, checkPresent bool, protocol, schedMethod, serviceAddress string) { + out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput() + require.NoError(t, err) + + for _, o := range strings.Split(string(out), "\n") { + cmpStr := serviceAddress + if protocol == "FWM" { + cmpStr = " " + cmpStr + } + + if strings.Contains(o, cmpStr) { + parts := strings.Split(o, " ") + assert.Equal(t, protocol, parts[0]) + assert.Equal(t, serviceAddress, parts[2]) + assert.Equal(t, schedMethod, parts[3]) + + if !checkPresent { + t.Fatalf("Did not expect the service %s in ipvs output", serviceAddress) + } + + return + } + } + + if checkPresent { + t.Fatalf("Did not find the service %s in ipvs output", serviceAddress) + } +} + +func TestGetFamily(t *testing.T) { + if testutils.RunningOnCircleCI() { + t.Skipf("Skipping as not supported on CIRCLE CI kernel") + } + + id, err := getIPVSFamily() + require.NoError(t, err) + assert.NotEqual(t, 0, id) +} + +func TestService(t *testing.T) { + if testutils.RunningOnCircleCI() { + t.Skipf("Skipping as not supported on CIRCLE CI kernel") + } + + defer testutils.SetupTestOSContext(t)() + + i, err := New("") + require.NoError(t, err) + + for _, protocol := range protocols { + for _, schedMethod := range schedMethods { + var serviceAddress string + + s := Service{ + AddressFamily: nl.FAMILY_V4, + SchedName: schedMethod, + } + + switch protocol { + case "FWM": + s.FWMark = 1234 + serviceAddress = fmt.Sprintf("%d", 1234) + case "TCP": + s.Protocol = syscall.IPPROTO_TCP + s.Port = 80 + s.Address = net.ParseIP("1.2.3.4") + s.Netmask = 0xFFFFFFFF + serviceAddress = "1.2.3.4:80" + case "UDP": + s.Protocol = syscall.IPPROTO_UDP + s.Port = 53 + s.Address = net.ParseIP("2.3.4.5") + serviceAddress = "2.3.4.5:53" + } + + err := i.NewService(&s) + assert.NoError(t, err) + checkService(t, true, protocol, schedMethod, serviceAddress) + var lastMethod string + for _, updateSchedMethod := range schedMethods { + if updateSchedMethod == schedMethod { + continue + } + + s.SchedName = updateSchedMethod + err = i.UpdateService(&s) + assert.NoError(t, err) + checkService(t, true, protocol, updateSchedMethod, serviceAddress) + lastMethod = updateSchedMethod + } + + err = i.DelService(&s) + checkService(t, false, protocol, lastMethod, serviceAddress) + } + } + +} + +func createDummyInterface(t *testing.T) { + if testutils.RunningOnCircleCI() { + t.Skipf("Skipping as not supported on CIRCLE CI kernel") + } + + dummy := &netlink.Dummy{ + LinkAttrs: netlink.LinkAttrs{ + Name: "dummy", + }, + } + + err := netlink.LinkAdd(dummy) + require.NoError(t, err) + + dummyLink, err := netlink.LinkByName("dummy") + require.NoError(t, err) + + ip, ipNet, err := net.ParseCIDR("10.1.1.1/24") + require.NoError(t, err) + + ipNet.IP = ip + + ipAddr := &netlink.Addr{IPNet: ipNet, Label: ""} + err = netlink.AddrAdd(dummyLink, ipAddr) + require.NoError(t, err) +} + +func TestDestination(t *testing.T) { + defer testutils.SetupTestOSContext(t)() + + createDummyInterface(t) + i, err := New("") + require.NoError(t, err) + + for _, protocol := range []string{"TCP"} { + var serviceAddress string + + s := Service{ + AddressFamily: nl.FAMILY_V4, + SchedName: RoundRobin, + } + + switch protocol { + case "FWM": + s.FWMark = 1234 + serviceAddress = fmt.Sprintf("%d", 1234) + case "TCP": + s.Protocol = syscall.IPPROTO_TCP + s.Port = 80 + s.Address = net.ParseIP("1.2.3.4") + s.Netmask = 0xFFFFFFFF + serviceAddress = "1.2.3.4:80" + case "UDP": + s.Protocol = syscall.IPPROTO_UDP + s.Port = 53 + s.Address = net.ParseIP("2.3.4.5") + serviceAddress = "2.3.4.5:53" + } + + err := i.NewService(&s) + assert.NoError(t, err) + checkService(t, true, protocol, RoundRobin, serviceAddress) + + s.SchedName = "" + for j, fwdMethod := range fwdMethods { + d1 := Destination{ + AddressFamily: nl.FAMILY_V4, + Address: net.ParseIP("10.1.1.2"), + Port: 5000, + Weight: 1, + ConnectionFlags: fwdMethod, + } + + realAddress := "10.1.1.2:5000" + err := i.NewDestination(&s, &d1) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j]) + d2 := Destination{ + AddressFamily: nl.FAMILY_V4, + Address: net.ParseIP("10.1.1.3"), + Port: 5000, + Weight: 1, + ConnectionFlags: fwdMethod, + } + + realAddress = "10.1.1.3:5000" + err = i.NewDestination(&s, &d2) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j]) + + d3 := Destination{ + AddressFamily: nl.FAMILY_V4, + Address: net.ParseIP("10.1.1.4"), + Port: 5000, + Weight: 1, + ConnectionFlags: fwdMethod, + } + + realAddress = "10.1.1.4:5000" + err = i.NewDestination(&s, &d3) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j]) + + for m, updateFwdMethod := range fwdMethods { + if updateFwdMethod == fwdMethod { + continue + } + d1.ConnectionFlags = updateFwdMethod + realAddress = "10.1.1.2:5000" + err = i.UpdateDestination(&s, &d1) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m]) + + d2.ConnectionFlags = updateFwdMethod + realAddress = "10.1.1.3:5000" + err = i.UpdateDestination(&s, &d2) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m]) + + d3.ConnectionFlags = updateFwdMethod + realAddress = "10.1.1.4:5000" + err = i.UpdateDestination(&s, &d3) + assert.NoError(t, err) + checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m]) + } + + err = i.DelDestination(&s, &d1) + assert.NoError(t, err) + err = i.DelDestination(&s, &d2) + assert.NoError(t, err) + err = i.DelDestination(&s, &d3) + assert.NoError(t, err) + } + } +} diff --git a/libnetwork/ipvs/netlink.go b/libnetwork/ipvs/netlink.go new file mode 100644 index 0000000000..ab6ec6e91f --- /dev/null +++ b/libnetwork/ipvs/netlink.go @@ -0,0 +1,227 @@ +// +build linux + +package ipvs + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "sync" + "syscall" + "unsafe" + + "github.com/vishvananda/netlink/nl" + "github.com/vishvananda/netns" +) + +var ( + native = nl.NativeEndian() + ipvsFamily int + ipvsOnce sync.Once +) + +type genlMsgHdr struct { + cmd uint8 + version uint8 + reserved uint16 +} + +type ipvsFlags struct { + flags uint32 + mask uint32 +} + +func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) { + return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0])) +} + +func (hdr *genlMsgHdr) Serialize() []byte { + return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:] +} + +func (hdr *genlMsgHdr) Len() int { + return int(unsafe.Sizeof(*hdr)) +} + +func (f *ipvsFlags) Serialize() []byte { + return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:] +} + +func (f *ipvsFlags) Len() int { + return int(unsafe.Sizeof(*f)) +} + +func setup() { + ipvsOnce.Do(func() { + var err error + ipvsFamily, err = getIPVSFamily() + if err != nil { + panic("could not get ipvs family") + } + }) +} + +func fillService(s *Service) nl.NetlinkRequestData { + cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil) + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily)) + if s.FWMark != 0 { + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark)) + } else { + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol)) + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address)) + + // Port needs to be in network byte order. + portBuf := new(bytes.Buffer) + binary.Write(portBuf, binary.BigEndian, s.Port) + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes()) + } + + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName)) + if s.PEName != "" { + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName)) + } + + f := &ipvsFlags{ + flags: s.Flags, + mask: 0xFFFFFFFF, + } + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize()) + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout)) + nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask)) + return cmdAttr +} + +func fillDestinaton(d *Destination) nl.NetlinkRequestData { + cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil) + + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address)) + // Port needs to be in network byte order. + portBuf := new(bytes.Buffer) + binary.Write(portBuf, binary.BigEndian, d.Port) + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes()) + + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask)) + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight))) + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold)) + nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold)) + + return cmdAttr +} + +func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error { + req := newIPVSRequest(cmd) + req.AddData(fillService(s)) + + if d != nil { + req.AddData(fillDestinaton(d)) + } + + if _, err := execute(i.sock, req, 0); err != nil { + return err + } + + return nil +} + +func getIPVSFamily() (int, error) { + sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC) + if err != nil { + return 0, err + } + + req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily) + req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS"))) + + msgs, err := execute(sock, req, 0) + if err != nil { + return 0, err + } + + for _, m := range msgs { + hdr := deserializeGenlMsg(m) + attrs, err := nl.ParseRouteAttr(m[hdr.Len():]) + if err != nil { + return 0, err + } + + for _, attr := range attrs { + switch int(attr.Attr.Type) { + case genlCtrlAttrFamilyID: + return int(native.Uint16(attr.Value[0:2])), nil + } + } + } + + return 0, fmt.Errorf("no family id in the netlink response") +} + +func rawIPData(ip net.IP) []byte { + family := nl.GetIPFamily(ip) + if family == nl.FAMILY_V4 { + return ip.To4() + } + + return ip +} + +func newIPVSRequest(cmd uint8) *nl.NetlinkRequest { + return newGenlRequest(ipvsFamily, cmd) +} + +func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest { + req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK) + req.AddData(&genlMsgHdr{cmd: cmd, version: 1}) + return req +} + +func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) { + var ( + err error + ) + + if err := s.Send(req); err != nil { + return nil, err + } + + pid, err := s.GetPid() + if err != nil { + return nil, err + } + + var res [][]byte + +done: + for { + msgs, err := s.Receive() + if err != nil { + return nil, err + } + for _, m := range msgs { + if m.Header.Seq != req.Seq { + return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq) + } + if m.Header.Pid != pid { + return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid) + } + if m.Header.Type == syscall.NLMSG_DONE { + break done + } + if m.Header.Type == syscall.NLMSG_ERROR { + error := int32(native.Uint32(m.Data[0:4])) + if error == 0 { + break done + } + return nil, syscall.Errno(-error) + } + if resType != 0 && m.Header.Type != resType { + continue + } + res = append(res, m.Data) + if m.Header.Flags&syscall.NLM_F_MULTI == 0 { + break done + } + } + } + return res, nil +}