Merge pull request #1191 from mrjana/ipvs

Add IPVS netlink support
This commit is contained in:
Madhu Venugopal 2016-05-25 02:58:13 -07:00
commit d7946ec4d8
4 changed files with 791 additions and 0 deletions

View File

@ -0,0 +1,130 @@
// +build linux
package ipvs
const (
genlCtrlID = 0x10
)
// GENL control commands
const (
genlCtrlCmdUnspec uint8 = iota
genlCtrlCmdNewFamily
genlCtrlCmdDelFamily
genlCtrlCmdGetFamily
)
// GENL family attributes
const (
genlCtrlAttrUnspec int = iota
genlCtrlAttrFamilyID
genlCtrlAttrFamilyName
)
// IPVS genl commands
const (
ipvsCmdUnspec uint8 = iota
ipvsCmdNewService
ipvsCmdSetService
ipvsCmdDelService
ipvsCmdGetService
ipvsCmdNewDest
ipvsCmdSetDest
ipvsCmdDelDest
ipvsCmdGetDest
ipvsCmdNewDaemon
ipvsCmdDelDaemon
ipvsCmdGetDaemon
ipvsCmdSetConfig
ipvsCmdGetConfig
ipvsCmdSetInfo
ipvsCmdGetInfo
ipvsCmdZero
ipvsCmdFlush
)
// Attributes used in the first level of commands
const (
ipvsCmdAttrUnspec int = iota
ipvsCmdAttrService
ipvsCmdAttrDest
ipvsCmdAttrDaemon
ipvsCmdAttrTimeoutTCP
ipvsCmdAttrTimeoutTCPFin
ipvsCmdAttrTimeoutUDP
)
// Attributes used to describe a service. Used inside nested attribute
// ipvsCmdAttrService
const (
ipvsSvcAttrUnspec int = iota
ipvsSvcAttrAddressFamily
ipvsSvcAttrProtocol
ipvsSvcAttrAddress
ipvsSvcAttrPort
ipvsSvcAttrFWMark
ipvsSvcAttrSchedName
ipvsSvcAttrFlags
ipvsSvcAttrTimeout
ipvsSvcAttrNetmask
ipvsSvcAttrStats
ipvsSvcAttrPEName
)
// Attributes used to describe a destination (real server). Used
// inside nested attribute ipvsCmdAttrDest.
const (
ipvsDestAttrUnspec int = iota
ipvsDestAttrAddress
ipvsDestAttrPort
ipvsDestAttrForwardingMethod
ipvsDestAttrWeight
ipvsDestAttrUpperThreshold
ipvsDestAttrLowerThreshold
ipvsDestAttrActiveConnections
ipvsDestAttrInactiveConnections
ipvsDestAttrPersistentConnections
ipvsDestAttrStats
)
// Destination forwarding methods
const (
// ConnectionFlagFwdmask indicates the mask in the connection
// flags which is used by forwarding method bits.
ConnectionFlagFwdMask = 0x0007
// ConnectionFlagMasq is used for masquerade forwarding method.
ConnectionFlagMasq = 0x0000
// ConnectionFlagLocalNode is used for local node forwarding
// method.
ConnectionFlagLocalNode = 0x0001
// ConnectionFlagTunnel is used for tunnel mode forwarding
// method.
ConnectionFlagTunnel = 0x0002
// ConnectionFlagDirectRoute is used for direct routing
// forwarding method.
ConnectionFlagDirectRoute = 0x0003
)
const (
// RoundRobin distributes jobs equally amongst the available
// real servers.
RoundRobin = "rr"
// LeastConnection assigns more jobs to real servers with
// fewer active jobs.
LeastConnection = "lc"
// DestinationHashing assigns jobs to servers through looking
// up a statically assigned hash table by their destination IP
// addresses.
DestinationHashing = "dh"
// SourceHashing assigns jobs to servers through looking up
// a statically assigned hash table by their source IP
// addresses.
SourceHashing = "sh"
)

113
libnetwork/ipvs/ipvs.go Normal file
View File

@ -0,0 +1,113 @@
// +build linux
package ipvs
import (
"net"
"syscall"
"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
)
// Service defines an IPVS service in its entirety.
type Service struct {
// Virtual service address.
Address net.IP
Protocol uint16
Port uint16
FWMark uint32 // Firewall mark of the service.
// Virtual service options.
SchedName string
Flags uint32
Timeout uint32
Netmask uint32
AddressFamily uint16
PEName string
}
// Destination defines an IPVS destination (real server) in its
// entirety.
type Destination struct {
Address net.IP
Port uint16
Weight int
ConnectionFlags uint32
AddressFamily uint16
UpperThreshold uint32
LowerThreshold uint32
}
// Handle provides a namespace specific ipvs handle to program ipvs
// rules.
type Handle struct {
sock *nl.NetlinkSocket
}
// New provides a new ipvs handle in the namespace pointed to by the
// passed path. It will return a valid handle or an error in case an
// error occured while creating the handle.
func New(path string) (*Handle, error) {
setup()
n := netns.None()
if path != "" {
var err error
n, err = netns.GetFromPath(path)
if err != nil {
return nil, err
}
}
sock, err := nl.GetNetlinkSocketAt(n, netns.None(), syscall.NETLINK_GENERIC)
if err != nil {
n.Close()
return nil, err
}
return &Handle{sock: sock}, nil
}
// Close closes the ipvs handle. The handle is invalid after Close
// returns.
func (i *Handle) Close() {
if i.sock != nil {
i.sock.Close()
}
}
// NewService creates a new ipvs service in the passed handle.
func (i *Handle) NewService(s *Service) error {
return i.doCmd(s, nil, ipvsCmdNewService)
}
// UpdateService updates an already existing service in the passed
// handle.
func (i *Handle) UpdateService(s *Service) error {
return i.doCmd(s, nil, ipvsCmdSetService)
}
// DelService deletes an already existing service in the passed
// handle.
func (i *Handle) DelService(s *Service) error {
return i.doCmd(s, nil, ipvsCmdDelService)
}
// NewDestination creates an new real server in the passed ipvs
// service which should already be existing in the passed handle.
func (i *Handle) NewDestination(s *Service, d *Destination) error {
return i.doCmd(s, d, ipvsCmdNewDest)
}
// UpdateDestination updates an already existing real server in the
// passed ipvs service in the passed handle.
func (i *Handle) UpdateDestination(s *Service, d *Destination) error {
return i.doCmd(s, d, ipvsCmdSetDest)
}
// DelDestination deletes an already existing real server in the
// passed ipvs service in the passed handle.
func (i *Handle) DelDestination(s *Service, d *Destination) error {
return i.doCmd(s, d, ipvsCmdDelDest)
}

View File

@ -0,0 +1,321 @@
// +build linux
package ipvs
import (
"fmt"
"net"
"os/exec"
"strings"
"syscall"
"testing"
"github.com/docker/libnetwork/testutils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
)
var (
schedMethods = []string{
RoundRobin,
LeastConnection,
DestinationHashing,
SourceHashing,
}
protocols = []string{
"TCP",
"UDP",
"FWM",
}
fwdMethods = []uint32{
ConnectionFlagMasq,
ConnectionFlagTunnel,
ConnectionFlagDirectRoute,
}
fwdMethodStrings = []string{
"Masq",
"Tunnel",
"Route",
}
)
func checkDestination(t *testing.T, checkPresent bool, protocol, serviceAddress, realAddress, fwdMethod string) {
var (
realServerStart bool
realServers []string
)
out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput()
require.NoError(t, err)
for _, o := range strings.Split(string(out), "\n") {
cmpStr := serviceAddress
if protocol == "FWM" {
cmpStr = " " + cmpStr
}
if strings.Contains(o, cmpStr) {
realServerStart = true
continue
}
if realServerStart {
if !strings.Contains(o, "->") {
break
}
realServers = append(realServers, o)
}
}
for _, r := range realServers {
if strings.Contains(r, realAddress) {
parts := strings.Fields(r)
assert.Equal(t, fwdMethod, parts[2])
return
}
}
if checkPresent {
t.Fatalf("Did not find the destination %s fwdMethod %s in ipvs output", realAddress, fwdMethod)
}
}
func checkService(t *testing.T, checkPresent bool, protocol, schedMethod, serviceAddress string) {
out, err := exec.Command("ipvsadm", "-Ln").CombinedOutput()
require.NoError(t, err)
for _, o := range strings.Split(string(out), "\n") {
cmpStr := serviceAddress
if protocol == "FWM" {
cmpStr = " " + cmpStr
}
if strings.Contains(o, cmpStr) {
parts := strings.Split(o, " ")
assert.Equal(t, protocol, parts[0])
assert.Equal(t, serviceAddress, parts[2])
assert.Equal(t, schedMethod, parts[3])
if !checkPresent {
t.Fatalf("Did not expect the service %s in ipvs output", serviceAddress)
}
return
}
}
if checkPresent {
t.Fatalf("Did not find the service %s in ipvs output", serviceAddress)
}
}
func TestGetFamily(t *testing.T) {
if testutils.RunningOnCircleCI() {
t.Skipf("Skipping as not supported on CIRCLE CI kernel")
}
id, err := getIPVSFamily()
require.NoError(t, err)
assert.NotEqual(t, 0, id)
}
func TestService(t *testing.T) {
if testutils.RunningOnCircleCI() {
t.Skipf("Skipping as not supported on CIRCLE CI kernel")
}
defer testutils.SetupTestOSContext(t)()
i, err := New("")
require.NoError(t, err)
for _, protocol := range protocols {
for _, schedMethod := range schedMethods {
var serviceAddress string
s := Service{
AddressFamily: nl.FAMILY_V4,
SchedName: schedMethod,
}
switch protocol {
case "FWM":
s.FWMark = 1234
serviceAddress = fmt.Sprintf("%d", 1234)
case "TCP":
s.Protocol = syscall.IPPROTO_TCP
s.Port = 80
s.Address = net.ParseIP("1.2.3.4")
s.Netmask = 0xFFFFFFFF
serviceAddress = "1.2.3.4:80"
case "UDP":
s.Protocol = syscall.IPPROTO_UDP
s.Port = 53
s.Address = net.ParseIP("2.3.4.5")
serviceAddress = "2.3.4.5:53"
}
err := i.NewService(&s)
assert.NoError(t, err)
checkService(t, true, protocol, schedMethod, serviceAddress)
var lastMethod string
for _, updateSchedMethod := range schedMethods {
if updateSchedMethod == schedMethod {
continue
}
s.SchedName = updateSchedMethod
err = i.UpdateService(&s)
assert.NoError(t, err)
checkService(t, true, protocol, updateSchedMethod, serviceAddress)
lastMethod = updateSchedMethod
}
err = i.DelService(&s)
checkService(t, false, protocol, lastMethod, serviceAddress)
}
}
}
func createDummyInterface(t *testing.T) {
if testutils.RunningOnCircleCI() {
t.Skipf("Skipping as not supported on CIRCLE CI kernel")
}
dummy := &netlink.Dummy{
LinkAttrs: netlink.LinkAttrs{
Name: "dummy",
},
}
err := netlink.LinkAdd(dummy)
require.NoError(t, err)
dummyLink, err := netlink.LinkByName("dummy")
require.NoError(t, err)
ip, ipNet, err := net.ParseCIDR("10.1.1.1/24")
require.NoError(t, err)
ipNet.IP = ip
ipAddr := &netlink.Addr{IPNet: ipNet, Label: ""}
err = netlink.AddrAdd(dummyLink, ipAddr)
require.NoError(t, err)
}
func TestDestination(t *testing.T) {
defer testutils.SetupTestOSContext(t)()
createDummyInterface(t)
i, err := New("")
require.NoError(t, err)
for _, protocol := range []string{"TCP"} {
var serviceAddress string
s := Service{
AddressFamily: nl.FAMILY_V4,
SchedName: RoundRobin,
}
switch protocol {
case "FWM":
s.FWMark = 1234
serviceAddress = fmt.Sprintf("%d", 1234)
case "TCP":
s.Protocol = syscall.IPPROTO_TCP
s.Port = 80
s.Address = net.ParseIP("1.2.3.4")
s.Netmask = 0xFFFFFFFF
serviceAddress = "1.2.3.4:80"
case "UDP":
s.Protocol = syscall.IPPROTO_UDP
s.Port = 53
s.Address = net.ParseIP("2.3.4.5")
serviceAddress = "2.3.4.5:53"
}
err := i.NewService(&s)
assert.NoError(t, err)
checkService(t, true, protocol, RoundRobin, serviceAddress)
s.SchedName = ""
for j, fwdMethod := range fwdMethods {
d1 := Destination{
AddressFamily: nl.FAMILY_V4,
Address: net.ParseIP("10.1.1.2"),
Port: 5000,
Weight: 1,
ConnectionFlags: fwdMethod,
}
realAddress := "10.1.1.2:5000"
err := i.NewDestination(&s, &d1)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
d2 := Destination{
AddressFamily: nl.FAMILY_V4,
Address: net.ParseIP("10.1.1.3"),
Port: 5000,
Weight: 1,
ConnectionFlags: fwdMethod,
}
realAddress = "10.1.1.3:5000"
err = i.NewDestination(&s, &d2)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
d3 := Destination{
AddressFamily: nl.FAMILY_V4,
Address: net.ParseIP("10.1.1.4"),
Port: 5000,
Weight: 1,
ConnectionFlags: fwdMethod,
}
realAddress = "10.1.1.4:5000"
err = i.NewDestination(&s, &d3)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[j])
for m, updateFwdMethod := range fwdMethods {
if updateFwdMethod == fwdMethod {
continue
}
d1.ConnectionFlags = updateFwdMethod
realAddress = "10.1.1.2:5000"
err = i.UpdateDestination(&s, &d1)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
d2.ConnectionFlags = updateFwdMethod
realAddress = "10.1.1.3:5000"
err = i.UpdateDestination(&s, &d2)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
d3.ConnectionFlags = updateFwdMethod
realAddress = "10.1.1.4:5000"
err = i.UpdateDestination(&s, &d3)
assert.NoError(t, err)
checkDestination(t, true, protocol, serviceAddress, realAddress, fwdMethodStrings[m])
}
err = i.DelDestination(&s, &d1)
assert.NoError(t, err)
err = i.DelDestination(&s, &d2)
assert.NoError(t, err)
err = i.DelDestination(&s, &d3)
assert.NoError(t, err)
}
}
}

227
libnetwork/ipvs/netlink.go Normal file
View File

@ -0,0 +1,227 @@
// +build linux
package ipvs
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"sync"
"syscall"
"unsafe"
"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
)
var (
native = nl.NativeEndian()
ipvsFamily int
ipvsOnce sync.Once
)
type genlMsgHdr struct {
cmd uint8
version uint8
reserved uint16
}
type ipvsFlags struct {
flags uint32
mask uint32
}
func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0]))
}
func (hdr *genlMsgHdr) Serialize() []byte {
return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:]
}
func (hdr *genlMsgHdr) Len() int {
return int(unsafe.Sizeof(*hdr))
}
func (f *ipvsFlags) Serialize() []byte {
return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:]
}
func (f *ipvsFlags) Len() int {
return int(unsafe.Sizeof(*f))
}
func setup() {
ipvsOnce.Do(func() {
var err error
ipvsFamily, err = getIPVSFamily()
if err != nil {
panic("could not get ipvs family")
}
})
}
func fillService(s *Service) nl.NetlinkRequestData {
cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil)
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))
if s.FWMark != 0 {
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
} else {
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address))
// Port needs to be in network byte order.
portBuf := new(bytes.Buffer)
binary.Write(portBuf, binary.BigEndian, s.Port)
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes())
}
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
if s.PEName != "" {
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
}
f := &ipvsFlags{
flags: s.Flags,
mask: 0xFFFFFFFF,
}
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize())
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
return cmdAttr
}
func fillDestinaton(d *Destination) nl.NetlinkRequestData {
cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil)
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address))
// Port needs to be in network byte order.
portBuf := new(bytes.Buffer)
binary.Write(portBuf, binary.BigEndian, d.Port)
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes())
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask))
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
return cmdAttr
}
func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error {
req := newIPVSRequest(cmd)
req.AddData(fillService(s))
if d != nil {
req.AddData(fillDestinaton(d))
}
if _, err := execute(i.sock, req, 0); err != nil {
return err
}
return nil
}
func getIPVSFamily() (int, error) {
sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC)
if err != nil {
return 0, err
}
req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily)
req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS")))
msgs, err := execute(sock, req, 0)
if err != nil {
return 0, err
}
for _, m := range msgs {
hdr := deserializeGenlMsg(m)
attrs, err := nl.ParseRouteAttr(m[hdr.Len():])
if err != nil {
return 0, err
}
for _, attr := range attrs {
switch int(attr.Attr.Type) {
case genlCtrlAttrFamilyID:
return int(native.Uint16(attr.Value[0:2])), nil
}
}
}
return 0, fmt.Errorf("no family id in the netlink response")
}
func rawIPData(ip net.IP) []byte {
family := nl.GetIPFamily(ip)
if family == nl.FAMILY_V4 {
return ip.To4()
}
return ip
}
func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
return newGenlRequest(ipvsFamily, cmd)
}
func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
req.AddData(&genlMsgHdr{cmd: cmd, version: 1})
return req
}
func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
var (
err error
)
if err := s.Send(req); err != nil {
return nil, err
}
pid, err := s.GetPid()
if err != nil {
return nil, err
}
var res [][]byte
done:
for {
msgs, err := s.Receive()
if err != nil {
return nil, err
}
for _, m := range msgs {
if m.Header.Seq != req.Seq {
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
}
if m.Header.Pid != pid {
return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
}
if m.Header.Type == syscall.NLMSG_DONE {
break done
}
if m.Header.Type == syscall.NLMSG_ERROR {
error := int32(native.Uint32(m.Data[0:4]))
if error == 0 {
break done
}
return nil, syscall.Errno(-error)
}
if resType != 0 && m.Header.Type != resType {
continue
}
res = append(res, m.Data)
if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
break done
}
}
}
return res, nil
}