2016-05-24 16:12:16 -04:00
// +build linux
package ipvs
import (
"bytes"
"encoding/binary"
"fmt"
"net"
2016-05-31 02:55:51 -04:00
"os/exec"
"strings"
2016-05-24 16:12:16 -04:00
"sync"
2017-02-01 17:55:39 -05:00
"sync/atomic"
2016-05-24 16:12:16 -04:00
"syscall"
"unsafe"
2016-05-31 02:55:51 -04:00
"github.com/Sirupsen/logrus"
2016-05-24 16:12:16 -04:00
"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
)
var (
	// ipvsOnce makes sure the initialisation in setup runs a single time.
	ipvsOnce sync.Once
	// ipvsFamily holds the generic netlink family ID that setup resolves for
	// IPVS; requests built by newIPVSRequest are addressed to it.
	ipvsFamily int
	// native is the byte order used to decode integer values read out of
	// netlink replies (family IDs, error codes).
	native = nl.NativeEndian()
)
// genlMsgHdr is the fixed header placed first in the data of every request
// built by newGenlRequest, and stripped from replies before attribute
// parsing. It is presumably a mirror of the kernel's struct genlmsghdr.
// Serialize and deserializeGenlMsg copy/reinterpret its raw in-memory
// layout, so field order and sizes must not change.
type genlMsgHdr struct {
	cmd      uint8  // generic netlink command to run
	version  uint8  // header version; newGenlRequest always sends 1
	reserved uint16 // never set here, so it is sent as zero
}
// ipvsFlags is the flags/mask pair that fillService sends as the service
// flags attribute. Serialize emits its raw in-memory layout, so field order
// and sizes must not change.
type ipvsFlags struct {
	flags uint32 // service flag bits
	mask  uint32 // NOTE(review): presumably selects which flag bits apply; fillService always sends 0xFFFFFFFF
}
// deserializeGenlMsg reinterprets the leading bytes of b as a genlMsgHdr
// without copying; the returned header aliases b. It panics when cap(b) is
// smaller than the header size.
func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
	const size = unsafe.Sizeof(genlMsgHdr{})
	head := b[:size]
	return (*genlMsgHdr)(unsafe.Pointer(&head[0]))
}
// Serialize returns the raw in-memory bytes of the header. The slice
// aliases hdr rather than copying it.
func (hdr *genlMsgHdr) Serialize() []byte {
	const size = unsafe.Sizeof(genlMsgHdr{})
	raw := (*[size]byte)(unsafe.Pointer(hdr))
	return raw[:]
}
// Len reports how many bytes Serialize produces for the header.
func (hdr *genlMsgHdr) Len() int {
	const size = unsafe.Sizeof(genlMsgHdr{})
	return int(size)
}
// Serialize returns the raw in-memory bytes of the flags/mask pair. The
// slice aliases f rather than copying it.
func (f *ipvsFlags) Serialize() []byte {
	const size = unsafe.Sizeof(ipvsFlags{})
	raw := (*[size]byte)(unsafe.Pointer(f))
	return raw[:]
}
// Len reports how many bytes Serialize produces for the flags/mask pair.
func (f *ipvsFlags) Len() int {
	const size = unsafe.Sizeof(ipvsFlags{})
	return int(size)
}
func setup ( ) {
ipvsOnce . Do ( func ( ) {
var err error
2016-05-31 02:55:51 -04:00
if out , err := exec . Command ( "modprobe" , "-va" , "ip_vs" ) . CombinedOutput ( ) ; err != nil {
2016-06-10 11:32:39 -04:00
logrus . Warnf ( "Running modprobe ip_vs failed with message: `%s`, error: %v" , strings . TrimSpace ( string ( out ) ) , err )
2016-05-31 02:55:51 -04:00
}
2016-05-24 16:12:16 -04:00
ipvsFamily , err = getIPVSFamily ( )
if err != nil {
2016-11-01 00:26:14 -04:00
logrus . Error ( "Could not get ipvs family information from the kernel. It is possible that ipvs is not enabled in your kernel. Native loadbalancing will not work until this is fixed." )
2016-05-24 16:12:16 -04:00
}
} )
}
// fillService builds the nested service attribute describing s.
//
// When s.FWMark is non-zero, the mark is sent in place of the protocol,
// address and port. Scheduler name, flags (with an all-ones mask), timeout
// and netmask are always sent. The persistence-engine name is sent only
// when it is non-empty.
func fillService(s *Service) nl.NetlinkRequestData {
	svc := nl.NewRtAttr(ipvsCmdAttrService, nil)
	nl.NewRtAttrChild(svc, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))

	switch {
	case s.FWMark != 0:
		nl.NewRtAttrChild(svc, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
	default:
		// Port needs to be in network byte order.
		var port bytes.Buffer
		binary.Write(&port, binary.BigEndian, s.Port)

		nl.NewRtAttrChild(svc, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
		nl.NewRtAttrChild(svc, ipvsSvcAttrAddress, rawIPData(s.Address))
		nl.NewRtAttrChild(svc, ipvsSvcAttrPort, port.Bytes())
	}

	nl.NewRtAttrChild(svc, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
	if s.PEName != "" {
		nl.NewRtAttrChild(svc, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
	}

	svcFlags := ipvsFlags{flags: s.Flags, mask: 0xFFFFFFFF}
	nl.NewRtAttrChild(svc, ipvsSvcAttrFlags, svcFlags.Serialize())
	nl.NewRtAttrChild(svc, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
	nl.NewRtAttrChild(svc, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
	return svc
}
// fillDestinaton builds the nested destination (real server) attribute for
// d. It contains the address, port, forwarding method (the bits of
// d.ConnectionFlags selected by ConnectionFlagFwdMask), weight, and the
// upper and lower connection thresholds.
func fillDestinaton(d *Destination) nl.NetlinkRequestData {
	dest := nl.NewRtAttr(ipvsCmdAttrDest, nil)

	// Port needs to be in network byte order.
	var port bytes.Buffer
	binary.Write(&port, binary.BigEndian, d.Port)

	fwdMethod := d.ConnectionFlags & ConnectionFlagFwdMask

	nl.NewRtAttrChild(dest, ipvsDestAttrAddress, rawIPData(d.Address))
	nl.NewRtAttrChild(dest, ipvsDestAttrPort, port.Bytes())
	nl.NewRtAttrChild(dest, ipvsDestAttrForwardingMethod, nl.Uint32Attr(fwdMethod))
	nl.NewRtAttrChild(dest, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
	nl.NewRtAttrChild(dest, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
	nl.NewRtAttrChild(dest, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
	return dest
}
func ( i * Handle ) doCmd ( s * Service , d * Destination , cmd uint8 ) error {
req := newIPVSRequest ( cmd )
2017-02-01 17:55:39 -05:00
req . Seq = atomic . AddUint32 ( & i . seq , 1 )
2016-05-24 16:12:16 -04:00
req . AddData ( fillService ( s ) )
if d != nil {
req . AddData ( fillDestinaton ( d ) )
}
if _ , err := execute ( i . sock , req , 0 ) ; err != nil {
return err
}
return nil
}
func getIPVSFamily ( ) ( int , error ) {
sock , err := nl . GetNetlinkSocketAt ( netns . None ( ) , netns . None ( ) , syscall . NETLINK_GENERIC )
if err != nil {
return 0 , err
}
2017-02-03 12:00:59 -05:00
defer sock . Close ( )
2016-05-24 16:12:16 -04:00
req := newGenlRequest ( genlCtrlID , genlCtrlCmdGetFamily )
req . AddData ( nl . NewRtAttr ( genlCtrlAttrFamilyName , nl . ZeroTerminated ( "IPVS" ) ) )
msgs , err := execute ( sock , req , 0 )
if err != nil {
return 0 , err
}
for _ , m := range msgs {
hdr := deserializeGenlMsg ( m )
attrs , err := nl . ParseRouteAttr ( m [ hdr . Len ( ) : ] )
if err != nil {
return 0 , err
}
for _ , attr := range attrs {
switch int ( attr . Attr . Type ) {
case genlCtrlAttrFamilyID :
return int ( native . Uint16 ( attr . Value [ 0 : 2 ] ) ) , nil
}
}
}
return 0 , fmt . Errorf ( "no family id in the netlink response" )
}
// rawIPData returns the bytes to place in an address attribute: the 4-byte
// form when nl.GetIPFamily classifies ip as IPv4, otherwise ip unchanged.
func rawIPData(ip net.IP) []byte {
	if nl.GetIPFamily(ip) != nl.FAMILY_V4 {
		return ip
	}
	return ip.To4()
}
// newIPVSRequest builds a generic netlink request for cmd, addressed to the
// IPVS family ID that setup stored in ipvsFamily.
func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
	family := ipvsFamily
	return newGenlRequest(family, cmd)
}
// newGenlRequest creates a netlink request whose message type is familyID.
// NLM_F_ACK is set, and the request starts with a genetlink header carrying
// cmd and version 1.
func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
	r := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
	hdr := &genlMsgHdr{version: 1, cmd: cmd}
	r.AddData(hdr)
	return r
}
func execute ( s * nl . NetlinkSocket , req * nl . NetlinkRequest , resType uint16 ) ( [ ] [ ] byte , error ) {
var (
err error
)
if err := s . Send ( req ) ; err != nil {
return nil , err
}
pid , err := s . GetPid ( )
if err != nil {
return nil , err
}
var res [ ] [ ] byte
done :
for {
msgs , err := s . Receive ( )
if err != nil {
return nil , err
}
for _ , m := range msgs {
if m . Header . Seq != req . Seq {
2017-02-01 17:55:39 -05:00
continue
2016-05-24 16:12:16 -04:00
}
if m . Header . Pid != pid {
return nil , fmt . Errorf ( "Wrong pid %d, expected %d" , m . Header . Pid , pid )
}
if m . Header . Type == syscall . NLMSG_DONE {
break done
}
if m . Header . Type == syscall . NLMSG_ERROR {
error := int32 ( native . Uint32 ( m . Data [ 0 : 4 ] ) )
if error == 0 {
break done
}
return nil , syscall . Errno ( - error )
}
if resType != 0 && m . Header . Type != resType {
continue
}
res = append ( res , m . Data )
if m . Header . Flags & syscall . NLM_F_MULTI == 0 {
break done
}
}
}
return res , nil
}