From e4ce59b7aad649a8e329e9332c15d74f2de97302 Mon Sep 17 00:00:00 2001 From: Zhang Wei Date: Mon, 5 Oct 2015 13:53:45 +0800 Subject: [PATCH] Cleanup iptables after bridge network is removed Fixed #570 Clean unused iptables rules after bridge network is removed Signed-off-by: Zhang Wei --- libnetwork/drivers/bridge/bridge.go | 46 ++++++++++---------- libnetwork/drivers/bridge/setup_ip_tables.go | 10 ++++- libnetwork/iptables/firewalld_test.go | 2 +- libnetwork/iptables/iptables.go | 24 ++++++++-- libnetwork/iptables/iptables_test.go | 4 +- 5 files changed, 55 insertions(+), 31 deletions(-) diff --git a/libnetwork/drivers/bridge/bridge.go b/libnetwork/drivers/bridge/bridge.go index c2b9e89dcc..8959581391 100644 --- a/libnetwork/drivers/bridge/bridge.go +++ b/libnetwork/drivers/bridge/bridge.go @@ -41,6 +41,9 @@ const ( DefaultGatewayV6AuxKey = "DefaultGatewayIPv6" ) +type iptableCleanFunc func() error +type iptablesCleanFuncs []iptableCleanFunc + // configuration info for the "bridge" driver. type configuration struct { EnableIPForwarding bool @@ -92,12 +95,13 @@ type bridgeEndpoint struct { } type bridgeNetwork struct { - id string - bridge *bridgeInterface // The bridge's L3 interface - config *networkConfiguration - endpoints map[string]*bridgeEndpoint // key: endpoint id - portMapper *portmapper.PortMapper - driver *driver // The network's driver + id string + bridge *bridgeInterface // The bridge's L3 interface + config *networkConfiguration + endpoints map[string]*bridgeEndpoint // key: endpoint id + portMapper *portmapper.PortMapper + driver *driver // The network's driver + iptCleanFuncs iptablesCleanFuncs sync.Mutex } @@ -236,6 +240,10 @@ func parseErr(label, value, errString string) error { return types.BadRequestErrorf("failed to parse %s value: %v (%s)", label, value, errString) } +func (n *bridgeNetwork) registerIptCleanFunc(clean iptableCleanFunc) { + n.iptCleanFuncs = append(n.iptCleanFuncs, clean) +} + func (n *bridgeNetwork) getDriverChains() (*iptables.ChainInfo, *iptables.ChainInfo, error) { n.Lock() defer n.Unlock() @@ -602,6 +610,10 @@ func (d *driver) createNetwork(config *networkConfiguration) error { } return err } + network.registerIptCleanFunc(func() error { + nwList := d.getNetworks() + return network.isolateNetwork(nwList, false) + }) return nil } @@ -720,22 +732,6 @@ func (d *driver) DeleteNetwork(nid string) error { return err } - // In case of failures after this point, restore the network isolation rules - nwList := d.getNetworks() - defer func() { - if err != nil { - if err := n.isolateNetwork(nwList, true); err != nil { - logrus.Warnf("Failed on restoring the inter-network iptables rules on cleanup: %v", err) - } - } - }() - - // Remove inter-network communication rules. - err = n.isolateNetwork(nwList, false) - if err != nil { - return err - } - // We only delete the bridge when it's not the default bridge. This is keep the backward compatible behavior. if !config.DefaultBridge { if err := netlink.LinkDel(n.bridge.Link); err != nil { @@ -743,6 +739,12 @@ func (d *driver) DeleteNetwork(nid string) error { } } + // clean all relevant iptables rules + for _, cleanFunc := range n.iptCleanFuncs { + if errClean := cleanFunc(); errClean != nil { + logrus.Warnf("Failed to clean iptables rules for bridge network: %v", errClean) + } + } return d.storeDelete(config) } diff --git a/libnetwork/drivers/bridge/setup_ip_tables.go b/libnetwork/drivers/bridge/setup_ip_tables.go index f597f67b7f..6b1c5dcb11 100644 --- a/libnetwork/drivers/bridge/setup_ip_tables.go +++ b/libnetwork/drivers/bridge/setup_ip_tables.go @@ -68,21 +68,27 @@ func (n *bridgeNetwork) setupIPTables(config *networkConfiguration, i *bridgeInt if err = setupIPTablesInternal(config.BridgeName, maskedAddrv4, config.EnableICC, config.EnableIPMasquerade, hairpinMode, true); err != nil { return fmt.Errorf("Failed to Setup IP tables: %s", err.Error()) } + n.registerIptCleanFunc(func() error { + return setupIPTablesInternal(config.BridgeName, maskedAddrv4, config.EnableICC, config.EnableIPMasquerade, hairpinMode, false) + }) natChain, filterChain, err := n.getDriverChains() if err != nil { return fmt.Errorf("Failed to setup IP tables, cannot acquire chain info %s", err.Error()) } - err = iptables.ProgramChain(natChain, config.BridgeName, hairpinMode) + err = iptables.ProgramChain(natChain, config.BridgeName, hairpinMode, true) if err != nil { return fmt.Errorf("Failed to program NAT chain: %s", err.Error()) } - err = iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode) + err = iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode, true) if err != nil { return fmt.Errorf("Failed to program FILTER chain: %s", err.Error()) } + n.registerIptCleanFunc(func() error { + return iptables.ProgramChain(filterChain, config.BridgeName, hairpinMode, false) + }) n.portMapper.SetIptablesChain(filterChain, n.getNetworkBridgeName()) diff --git a/libnetwork/iptables/firewalld_test.go b/libnetwork/iptables/firewalld_test.go index 6607307564..1ac11a9adf 100644 --- a/libnetwork/iptables/firewalld_test.go +++ b/libnetwork/iptables/firewalld_test.go @@ -22,7 +22,7 @@ func TestReloaded(t *testing.T) { fwdChain, err = NewChain("FWD", Filter, false) bridgeName := "lo" - err = ProgramChain(fwdChain, bridgeName, false) + err = ProgramChain(fwdChain, bridgeName, false, true) if err != nil { t.Fatal(err) } diff --git a/libnetwork/iptables/iptables.go b/libnetwork/iptables/iptables.go index 4e24c3ac52..be7725f85d 100644 --- a/libnetwork/iptables/iptables.go +++ b/libnetwork/iptables/iptables.go @@ -95,7 +95,7 @@ func NewChain(name string, table Table, hairpinMode bool) (*ChainInfo, error) { } // ProgramChain is used to add rules to a chain -func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error { +func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode, enable bool) error { if c.Name == "" { return fmt.Errorf("Could not program chain, missing chain name.") } @@ -106,10 +106,14 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error { "-m", "addrtype", "--dst-type", "LOCAL", "-j", c.Name} - if !Exists(Nat, "PREROUTING", preroute...) { + if !Exists(Nat, "PREROUTING", preroute...) && enable { if err := c.Prerouting(Append, preroute...); err != nil { return fmt.Errorf("Failed to inject docker in PREROUTING chain: %s", err) } + } else if Exists(Nat, "PREROUTING", preroute...) && !enable { + if err := c.Prerouting(Delete, preroute...); err != nil { + return fmt.Errorf("Failed to remove docker in PREROUTING chain: %s", err) + } } output := []string{ "-m", "addrtype", @@ -118,10 +122,14 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error { if !hairpinMode { output = append(output, "!", "--dst", "127.0.0.0/8") } - if !Exists(Nat, "OUTPUT", output...) { + if !Exists(Nat, "OUTPUT", output...) && enable { if err := c.Output(Append, output...); err != nil { return fmt.Errorf("Failed to inject docker in OUTPUT chain: %s", err) } + } else if Exists(Nat, "OUTPUT", output...) && !enable { + if err := c.Output(Delete, output...); err != nil { + return fmt.Errorf("Failed to inject docker in OUTPUT chain: %s", err) + } } case Filter: if bridgeName == "" { @@ -131,13 +139,21 @@ func ProgramChain(c *ChainInfo, bridgeName string, hairpinMode bool) error { link := []string{ "-o", bridgeName, "-j", c.Name} - if !Exists(Filter, "FORWARD", link...) { + if !Exists(Filter, "FORWARD", link...) && enable { insert := append([]string{string(Insert), "FORWARD"}, link...) if output, err := Raw(insert...); err != nil { return err } else if len(output) != 0 { return fmt.Errorf("Could not create linking rule to %s/%s: %s", c.Table, c.Name, output) } + } else if Exists(Filter, "FORWARD", link...) && !enable { + del := append([]string{string(Delete), "FORWARD"}, link...) + if output, err := Raw(del...); err != nil { + return err + } else if len(output) != 0 { + return fmt.Errorf("Could not delete linking rule from %s/%s: %s", c.Table, c.Name, output) + } + } } return nil diff --git a/libnetwork/iptables/iptables_test.go b/libnetwork/iptables/iptables_test.go index 262ba18a72..2e5a2b5d25 100644 --- a/libnetwork/iptables/iptables_test.go +++ b/libnetwork/iptables/iptables_test.go @@ -22,13 +22,13 @@ func TestNewChain(t *testing.T) { bridgeName = "lo" natChain, err = NewChain(chainName, Nat, false) - err = ProgramChain(natChain, bridgeName, false) + err = ProgramChain(natChain, bridgeName, false, true) if err != nil { t.Fatal(err) } filterChain, err = NewChain(chainName, Filter, false) - err = ProgramChain(filterChain, bridgeName, false) + err = ProgramChain(filterChain, bridgeName, false, true) if err != nil { t.Fatal(err) }