diff --git a/libnetwork/drivers/overlay/ov_network.go b/libnetwork/drivers/overlay/ov_network.go index 3fbfccf007..11314170b1 100644 --- a/libnetwork/drivers/overlay/ov_network.go +++ b/libnetwork/drivers/overlay/ov_network.go @@ -696,6 +696,12 @@ func (n *network) initSandbox(restore bool) error { var nlSock *nl.NetlinkSocket sbox.InvokeFunc(func() { nlSock, err = nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH) + if err != nil { + return + } + // set the receive timeout to not remain stuck on the RecvFrom if the fd gets closed + tv := syscall.NsecToTimeval(soTimeout.Nanoseconds()) + err = nlSock.SetReceiveTimeout(&tv) }) n.setNetlinkSocket(nlSock) @@ -721,6 +727,11 @@ func (n *network) watchMiss(nlSock *nl.NetlinkSocket) { // The netlink socket got closed, simply exit to not leak this goroutine return } + // When the receive timeout expires the receive will return EAGAIN + if err == syscall.EAGAIN { + // we continue here to avoid spam for timeouts + continue + } logrus.Errorf("Failed to receive from netlink: %v ", err) continue } diff --git a/libnetwork/drivers/overlay/overlay_test.go b/libnetwork/drivers/overlay/overlay_test.go index 6d2127311d..75c89da6bb 100644 --- a/libnetwork/drivers/overlay/overlay_test.go +++ b/libnetwork/drivers/overlay/overlay_test.go @@ -1,7 +1,9 @@ package overlay import ( + "context" "net" + "syscall" "testing" "time" @@ -12,6 +14,7 @@ import ( "github.com/docker/libnetwork/driverapi" "github.com/docker/libnetwork/netlabel" _ "github.com/docker/libnetwork/testutils" + "github.com/vishvananda/netlink/nl" ) func init() { @@ -135,3 +138,36 @@ func TestOverlayType(t *testing.T) { dt.d.Type()) } } + +// Test that the netlink socket close unblock the watchMiss to avoid deadlock +func TestNetlinkSocket(t *testing.T) { + // This is the same code used by the overlay driver to create the netlink interface + // for the watch miss + nlSock, err := nl.Subscribe(syscall.NETLINK_ROUTE, syscall.RTNLGRP_NEIGH) + if err != nil { + t.Fatal() + } + // set the receive timeout to not remain stuck on the RecvFrom if the fd gets closed + tv := syscall.NsecToTimeval(soTimeout.Nanoseconds()) + err = nlSock.SetReceiveTimeout(&tv) + if err != nil { + t.Fatal() + } + n := &network{id: "testnetid"} + ch := make(chan error) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func() { + n.watchMiss(nlSock) + ch <- nil + }() + time.Sleep(5 * time.Second) + nlSock.Close() + select { + case <-ch: + case <-ctx.Done(): + { + t.Fatalf("Timeout expired") + } + } +} diff --git a/libnetwork/ipvs/ipvs.go b/libnetwork/ipvs/ipvs.go index ebcdd808c3..effbb716eb 100644 --- a/libnetwork/ipvs/ipvs.go +++ b/libnetwork/ipvs/ipvs.go @@ -5,12 +5,19 @@ package ipvs import ( "net" "syscall" + "time" "fmt" + "github.com/vishvananda/netlink/nl" "github.com/vishvananda/netns" ) +const ( + netlinkRecvSocketsTimeout = 3 * time.Second + netlinkSendSocketTimeout = 30 * time.Second +) + // Service defines an IPVS service in its entirety. type Service struct { // Virtual service address. @@ -82,6 +89,15 @@ func New(path string) (*Handle, error) { if err != nil { return nil, err } + // Add operation timeout to avoid deadlocks + tv := syscall.NsecToTimeval(netlinkSendSocketTimeout.Nanoseconds()) + if err := sock.SetSendTimeout(&tv); err != nil { + return nil, err + } + tv = syscall.NsecToTimeval(netlinkRecvSocketsTimeout.Nanoseconds()) + if err := sock.SetReceiveTimeout(&tv); err != nil { + return nil, err + } return &Handle{sock: sock}, nil } diff --git a/libnetwork/ipvs/netlink.go b/libnetwork/ipvs/netlink.go index 2089283d14..c062a1789d 100644 --- a/libnetwork/ipvs/netlink.go +++ b/libnetwork/ipvs/netlink.go @@ -203,10 +203,6 @@ func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest { } func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) { - var ( - err error - ) - if err := s.Send(req); err != nil { return nil, err } @@ -222,6 +218,13 @@ done: for { msgs, err := s.Receive() if err != nil { + if s.GetFd() == -1 { + return nil, fmt.Errorf("Socket got closed on receive") + } + if err == syscall.EAGAIN { + // timeout fired + continue + } return nil, err } for _, m := range msgs {