From 93b5073a7d529abc3efdc5faf0fdc4ab8189bd5c Mon Sep 17 00:00:00 2001 From: Alessandro Boch Date: Mon, 6 Jun 2016 18:17:10 -0700 Subject: [PATCH] Overlay driver to support network layer encryption Signed-off-by: Alessandro Boch --- libnetwork/agent.go | 59 +++ libnetwork/discoverapi/discoverapi.go | 24 + libnetwork/drivers/overlay/encryption.go | 578 +++++++++++++++++++++++ libnetwork/drivers/overlay/joinleave.go | 12 + libnetwork/drivers/overlay/ov_network.go | 74 ++- libnetwork/drivers/overlay/overlay.go | 39 +- libnetwork/drivers/overlay/peerdb.go | 12 +- 7 files changed, 781 insertions(+), 17 deletions(-) create mode 100644 libnetwork/drivers/overlay/encryption.go diff --git a/libnetwork/agent.go b/libnetwork/agent.go index f1d50c8908..cdfe0a783f 100644 --- a/libnetwork/agent.go +++ b/libnetwork/agent.go @@ -3,10 +3,12 @@ package libnetwork //go:generate protoc -I.:Godeps/_workspace/src/github.com/gogo/protobuf --gogo_out=import_path=github.com/docker/libnetwork,Mgogoproto/gogo.proto=github.com/gogo/protobuf/gogoproto:. agent.proto import ( + "encoding/hex" "fmt" "net" "os" "sort" + "strconv" "github.com/Sirupsen/logrus" "github.com/docker/go-events" @@ -72,6 +74,8 @@ func resolveAddr(addrOrInterface string) (string, error) { } func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error { + drvEnc := discoverapi.DriverEncryptionUpdate{} + // Find the new key and add it to the key ring a := c.agent for _, key := range keys { @@ -86,6 +90,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error { if key.Subsystem == "networking:gossip" { a.networkDB.SetKey(key.Key) } + if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ { + drvEnc.Key = hex.EncodeToString(key.Key) + drvEnc.Tag = strconv.FormatUint(key.LamportTime, 10) + } break } } @@ -103,6 +111,10 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error { if cKey.Subsystem == "networking:gossip" { deleted = cKey.Key } + if cKey.Subsystem == "networking:gossip" /*"networking:ipsec"*/ { + drvEnc.Prune = hex.EncodeToString(cKey.Key) + drvEnc.PruneTag = strconv.FormatUint(cKey.LamportTime, 10) + } c.keys = append(c.keys[:i], c.keys[i+1:]...) break } @@ -115,9 +127,25 @@ func (c *controller) handleKeyChange(keys []*types.EncryptionKey) error { break } } + for _, key := range c.keys { + if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ { + drvEnc.Primary = hex.EncodeToString(key.Key) + drvEnc.PrimaryTag = strconv.FormatUint(key.LamportTime, 10) + break + } + } if len(deleted) > 0 { a.networkDB.RemoveKey(deleted) } + + c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool { + err := driver.DiscoverNew(discoverapi.EncryptionKeysUpdate, drvEnc) + if err != nil { + logrus.Warnf("Failed to update datapath keys in driver %s: %v", name, err) + } + return false + }) + return nil } @@ -170,6 +198,8 @@ func (c *controller) agentInit(bindAddrOrInterface string) error { return nil } + drvEnc := discoverapi.DriverEncryptionConfig{} + // sort the keys by lamport time sort.Sort(ByTime(c.keys)) @@ -178,6 +208,10 @@ func (c *controller) agentInit(bindAddrOrInterface string) error { if key.Subsystem == "networking:gossip" { gossipkey = append(gossipkey, key.Key) } + if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ { + drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key)) + drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10)) + } } bindAddr, err := resolveAddr(bindAddrOrInterface) @@ -206,6 +240,15 @@ func (c *controller) agentInit(bindAddrOrInterface string) error { } go c.handleTableEvents(ch, c.handleEpTableEvent) + + c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool { + err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc) + if err != nil { + logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err) + } + return false + }) + return nil } @@ -226,6 +269,22 @@ func (c *controller) agentDriverNotify(d driverapi.Driver) { Address: c.agent.bindAddr, Self: true, }) + + drvEnc := discoverapi.DriverEncryptionConfig{} + for _, key := range c.keys { + if key.Subsystem == "networking:gossip" /*"networking:ipsec"*/ { + drvEnc.Keys = append(drvEnc.Keys, hex.EncodeToString(key.Key)) + drvEnc.Tags = append(drvEnc.Tags, strconv.FormatUint(key.LamportTime, 10)) + } + } + c.drvRegistry.WalkDrivers(func(name string, driver driverapi.Driver, capability driverapi.Capability) bool { + err := driver.DiscoverNew(discoverapi.EncryptionKeysConfig, drvEnc) + if err != nil { + logrus.Warnf("Failed to set datapath keys in driver %s: %v", name, err) + } + return false + }) + } func (c *controller) agentClose() { diff --git a/libnetwork/discoverapi/discoverapi.go b/libnetwork/discoverapi/discoverapi.go index eeacc3204e..080424a182 100644 --- a/libnetwork/discoverapi/discoverapi.go +++ b/libnetwork/discoverapi/discoverapi.go @@ -18,6 +18,10 @@ const ( NodeDiscovery = iota + 1 // DatastoreConfig represents an add/remove datastore event DatastoreConfig + // EncryptionKeysConfig represents the initial key(s) for performing datapath encryption + EncryptionKeysConfig + // EncryptionKeysUpdate represents an update to the datapath encryption key(s) + EncryptionKeysUpdate ) // NodeDiscoveryData represents the structure backing the node discovery data json string @@ -33,3 +37,23 @@ type DatastoreConfigData struct { Address string Config interface{} } + +// DriverEncryptionConfig contains the initial datapath encryption key(s) +// Key in first position is the primary key, the one to be used in tx. +// Original key and tag types are []byte and uint64 +type DriverEncryptionConfig struct { + Keys []string + Tags []string +} + +// DriverEncryptionUpdate carries an update to the encryption key(s) as: +// a new key and/or set a primary key and/or a removal of an existing key. +// Original key and tag types are []byte and uint64 +type DriverEncryptionUpdate struct { + Key string + Tag string + Primary string + PrimaryTag string + Prune string + PruneTag string +} diff --git a/libnetwork/drivers/overlay/encryption.go b/libnetwork/drivers/overlay/encryption.go new file mode 100644 index 0000000000..fc82ac3700 --- /dev/null +++ b/libnetwork/drivers/overlay/encryption.go @@ -0,0 +1,578 @@ +package overlay + +import ( + "bytes" + "encoding/hex" + "fmt" + "net" + "sync" + "syscall" + + log "github.com/Sirupsen/logrus" + "github.com/docker/libnetwork/iptables" + "github.com/docker/libnetwork/types" + "github.com/vishvananda/netlink" + "strconv" +) + +const ( + mark = uint32(0xD0C4E3) + timeout = 30 +) + +const ( + forward = iota + 1 + reverse + bidir +) + +type key struct { + value []byte + tag uint32 +} + +func (k *key) String() string { + return fmt.Sprintf("(key: %s, tag: 0x%x)", hex.EncodeToString(k.value)[0:5], k.tag) +} + +type spi struct { + forward int + reverse int +} + +func (s *spi) String() string { + return fmt.Sprintf("SPI(FWD: 0x%x, REV: 0x%x)", uint32(s.forward), uint32(s.reverse)) +} + +type encrMap struct { + nodes map[string][]*spi + sync.Mutex +} + +func (e *encrMap) String() string { + e.Lock() + defer e.Unlock() + b := new(bytes.Buffer) + for k, v := range e.nodes { + b.WriteString("\n") + b.WriteString(k) + b.WriteString(":") + b.WriteString("[") + for _, s := range v { + b.WriteString(s.String()) + b.WriteString(",") + } + b.WriteString("]") + + } + return b.String() +} + +func (d *driver) checkEncryption(nid string, rIP net.IP, vxlanID uint32, isLocal, add bool) error { + log.Infof("checkEncryption(%s, %v, %d, %t)", nid[0:7], rIP, vxlanID, isLocal) + + n := d.network(nid) + if n == nil || !n.secure { + return nil + } + + if len(d.keys) == 0 { + return types.ForbiddenErrorf("encryption key is not present") + } + + lIP := types.GetMinimalIP(net.ParseIP(d.bindAddress)) + nodes := map[string]net.IP{} + + switch { + case isLocal: + if err := d.peerDbNetworkWalk(nid, func(pKey *peerKey, pEntry *peerEntry) bool { + if !lIP.Equal(pEntry.vtep) { + nodes[pEntry.vtep.String()] = types.GetMinimalIP(pEntry.vtep) + } + return false + }); err != nil { + log.Warnf("Failed to retrieve list of participating nodes in overlay network %s: %v", nid[0:5], err) + } + default: + if len(d.network(nid).endpoints) > 0 { + nodes[rIP.String()] = types.GetMinimalIP(rIP) + } + } + + log.Debugf("List of nodes: %s", nodes) + + if add { + for _, rIP := range nodes { + if err := setupEncryption(lIP, rIP, vxlanID, d.secMap, d.keys); err != nil { + log.Warnf("Failed to program network encryption between %s and %s: %v", lIP, rIP, err) + } + } + } else { + if len(nodes) == 0 { + if err := removeEncryption(lIP, rIP, d.secMap); err != nil { + log.Warnf("Failed to remove network encryption between %s and %s: %v", lIP, rIP, err) + } + } + } + + return nil +} + +func setupEncryption(localIP, remoteIP net.IP, vni uint32, em *encrMap, keys []*key) error { + log.Infof("Programming encryption for vxlan %d between %s and %s", vni, localIP, remoteIP) + rIPs := remoteIP.String() + + indices := make([]*spi, 0, len(keys)) + + err := programMangle(vni, true) + if err != nil { + log.Warn(err) + } + + for i, k := range keys { + spis := &spi{buildSPI(localIP, remoteIP, k.tag), buildSPI(remoteIP, localIP, k.tag)} + dir := reverse + if i == 0 { + dir = bidir + } + fSA, rSA, err := programSA(localIP, remoteIP, spis, k, dir, true) + if err != nil { + log.Warn(err) + } + indices = append(indices, spis) + if i != 0 { + continue + } + err = programSP(fSA, rSA, true) + if err != nil { + log.Warn(err) + } + } + + em.Lock() + em.nodes[rIPs] = indices + em.Unlock() + + return nil +} + +func removeEncryption(localIP, remoteIP net.IP, em *encrMap) error { + em.Lock() + indices, ok := em.nodes[remoteIP.String()] + em.Unlock() + if !ok { + return nil + } + for i, idxs := range indices { + dir := reverse + if i == 0 { + dir = bidir + } + fSA, rSA, err := programSA(localIP, remoteIP, idxs, nil, dir, false) + if err != nil { + log.Warn(err) + } + if i != 0 { + continue + } + err = programSP(fSA, rSA, false) + if err != nil { + log.Warn(err) + } + } + return nil +} + +func programMangle(vni uint32, add bool) (err error) { + var ( + p = strconv.FormatUint(uint64(vxlanPort), 10) + c = fmt.Sprintf("0>>22&0x3C@12&0xFFFFFF00=%d", int(vni)<<8) + m = strconv.FormatUint(uint64(mark), 10) + chain = "OUTPUT" + rule = []string{"-p", "udp", "--dport", p, "-m", "u32", "--u32", c, "-j", "MARK", "--set-mark", m} + a = "-A" + action = "install" + ) + + if add == iptables.Exists(iptables.Mangle, chain, rule...) { + return + } + + if !add { + a = "-D" + action = "remove" + } + + if err = iptables.RawCombinedOutput(append([]string{"-t", string(iptables.Mangle), a, chain}, rule...)...); err != nil { + log.Warnf("could not %s mangle rule: %v", action, err) + } + + return +} + +func programSA(localIP, remoteIP net.IP, spi *spi, k *key, dir int, add bool) (fSA *netlink.XfrmState, rSA *netlink.XfrmState, err error) { + var ( + crypt *netlink.XfrmStateAlgo + action = "Removing" + xfrmProgram = netlink.XfrmStateDel + ) + + if add { + action = "Adding" + xfrmProgram = netlink.XfrmStateAdd + crypt = &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: k.value} + } + + if dir&reverse > 0 { + rSA = &netlink.XfrmState{ + Src: remoteIP, + Dst: localIP, + Proto: netlink.XFRM_PROTO_ESP, + Spi: spi.reverse, + Mode: netlink.XFRM_MODE_TRANSPORT, + } + if add { + rSA.Crypt = crypt + } + + exists, err := saExists(rSA) + if err != nil { + exists = !add + } + + if add != exists { + log.Infof("%s: rSA{%s}", action, rSA) + if err := xfrmProgram(rSA); err != nil { + log.Warnf("Failed %s rSA{%s}: %v", action, rSA, err) + } + } + } + + if dir&forward > 0 { + fSA = &netlink.XfrmState{ + Src: localIP, + Dst: remoteIP, + Proto: netlink.XFRM_PROTO_ESP, + Spi: spi.forward, + Mode: netlink.XFRM_MODE_TRANSPORT, + } + if add { + fSA.Crypt = crypt + } + + exists, err := saExists(fSA) + if err != nil { + exists = !add + } + + if add != exists { + log.Infof("%s fSA{%s}", action, fSA) + if err := xfrmProgram(fSA); err != nil { + log.Warnf("Failed %s fSA{%s}: %v.", action, fSA, err) + } + } + } + + return +} + +func programSP(fSA *netlink.XfrmState, rSA *netlink.XfrmState, add bool) error { + action := "Removing" + xfrmProgram := netlink.XfrmPolicyDel + if add { + action = "Adding" + xfrmProgram = netlink.XfrmPolicyAdd + } + + fullMask := net.CIDRMask(8*len(fSA.Src), 8*len(fSA.Src)) + + fPol := &netlink.XfrmPolicy{ + Src: &net.IPNet{IP: fSA.Src, Mask: fullMask}, + Dst: &net.IPNet{IP: fSA.Dst, Mask: fullMask}, + Dir: netlink.XFRM_DIR_OUT, + Proto: 17, + DstPort: 4789, + Mark: &netlink.XfrmMark{ + Value: mark, + }, + Tmpls: []netlink.XfrmPolicyTmpl{ + { + Src: fSA.Src, + Dst: fSA.Dst, + Proto: netlink.XFRM_PROTO_ESP, + Mode: netlink.XFRM_MODE_TRANSPORT, + Spi: fSA.Spi, + }, + }, + } + + exists, err := spExists(fPol) + if err != nil { + exists = !add + } + + if add != exists { + log.Infof("%s fSP{%s}", action, fPol) + if err := xfrmProgram(fPol); err != nil { + log.Warnf("%s fSP{%s}: %v", action, fPol, err) + } + } + + return nil +} + +func saExists(sa *netlink.XfrmState) (bool, error) { + _, err := netlink.XfrmStateGet(sa) + switch err { + case nil: + return true, nil + case syscall.ESRCH: + return false, nil + default: + err = fmt.Errorf("Error while checking for SA existence: %v", err) + log.Debug(err) + return false, err + } +} + +func spExists(sp *netlink.XfrmPolicy) (bool, error) { + _, err := netlink.XfrmPolicyGet(sp) + switch err { + case nil: + return true, nil + case syscall.ENOENT: + return false, nil + default: + err = fmt.Errorf("Error while checking for SP existence: %v", err) + log.Debug(err) + return false, err + } +} + +func buildSPI(src, dst net.IP, st uint32) int { + spi := int(st) + f := src[len(src)-4:] + t := dst[len(dst)-4:] + for i := 0; i < 4; i++ { + spi = spi ^ (int(f[i])^int(t[3-i]))< %rSA0, +rSA2, +fSA1, +fSP1/-fSP0, -fSA0, + * Half state: rSA0, rSA1, rSA2, fSA1, fSP1 + * Steady state: rSA1, rSA2, fSA1, fSP1 + *********************************************************/ + +// Spis and keys are sorted in such away the one in position 0 is the primary +func updateNodeKey(lIP, rIP net.IP, idxs []*spi, curKeys []*key, newIdx, priIdx, delIdx int) []*spi { + log.Infof("Updating keys for node: %s (%d,%d,%d)", rIP, newIdx, priIdx, delIdx) + + spis := idxs + log.Infof("Current: %v", spis) + + // add new + if newIdx != -1 { + spis = append(spis, &spi{ + forward: buildSPI(lIP, rIP, curKeys[newIdx].tag), + reverse: buildSPI(rIP, lIP, curKeys[newIdx].tag), + }) + } + + if delIdx != -1 { + // %rSA0 + rSA0 := &netlink.XfrmState{ + Src: rIP, + Dst: lIP, + Proto: netlink.XFRM_PROTO_ESP, + Spi: spis[delIdx].reverse, + Mode: netlink.XFRM_MODE_TRANSPORT, + Crypt: &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[delIdx].value}, + Limits: netlink.XfrmStateLimits{TimeSoft: timeout}, + } + log.Infof("Updating rSA0{%s}", rSA0) + if err := netlink.XfrmStateUpdate(rSA0); err != nil { + log.Warnf("Failed to update rSA0{%s}: %v", rSA0, err) + } + } + + if newIdx > -1 { + // +RSA2 + programSA(lIP, rIP, spis[newIdx], curKeys[newIdx], reverse, true) + } + + if priIdx > 0 { + // +fSA1 + fSA1, _, _ := programSA(lIP, rIP, spis[priIdx], curKeys[priIdx], forward, true) + + // +fSP1, -fSP0 + fullMask := net.CIDRMask(8*len(fSA1.Src), 8*len(fSA1.Src)) + fSP1 := &netlink.XfrmPolicy{ + Src: &net.IPNet{IP: fSA1.Src, Mask: fullMask}, + Dst: &net.IPNet{IP: fSA1.Dst, Mask: fullMask}, + Dir: netlink.XFRM_DIR_OUT, + Proto: 17, + DstPort: 4789, + Mark: &netlink.XfrmMark{ + Value: mark, + }, + Tmpls: []netlink.XfrmPolicyTmpl{ + { + Src: fSA1.Src, + Dst: fSA1.Dst, + Proto: netlink.XFRM_PROTO_ESP, + Mode: netlink.XFRM_MODE_TRANSPORT, + Spi: fSA1.Spi, + }, + }, + } + log.Infof("Updating fSP{%s}", fSP1) + if err := netlink.XfrmPolicyUpdate(fSP1); err != nil { + log.Warnf("Failed to update fSP{%s}: %v", fSP1, err) + } + + // -fSA0 + fSA0 := &netlink.XfrmState{ + Src: lIP, + Dst: rIP, + Proto: netlink.XFRM_PROTO_ESP, + Spi: spis[0].forward, + Mode: netlink.XFRM_MODE_TRANSPORT, + Crypt: &netlink.XfrmStateAlgo{Name: "cbc(aes)", Key: curKeys[0].value}, + Limits: netlink.XfrmStateLimits{TimeHard: timeout}, + } + log.Infof("Removing fSA0{%s}", fSA0) + if err := netlink.XfrmStateUpdate(fSA0); err != nil { + log.Warnf("Failed to remove fSA0{%s}: %v", fSA0, err) + } + } + + // swap + if priIdx > 0 { + swp := spis[0] + spis[0] = spis[priIdx] + spis[priIdx] = swp + } + // prune + if delIdx != -1 { + if delIdx == 0 { + delIdx = priIdx + } + spis = append(spis[:delIdx], spis[delIdx+1:]...) + } + + log.Infof("Updated: %v", spis) + + return spis +} + +func parseEncryptionKey(value, tag string) (*key, error) { + var ( + k *key + err error + ) + if value == "" { + return nil, nil + } + k = &key{} + if k.value, err = hex.DecodeString(value); err != nil { + return nil, types.BadRequestErrorf("failed to decode key (%s): %v", value, err) + } + t, err := strconv.ParseUint(tag, 10, 64) + if err != nil { + return nil, types.BadRequestErrorf("failed to decode tag (%s): %v", tag, err) + } + k.tag = uint32(t) + return k, nil +} diff --git a/libnetwork/drivers/overlay/joinleave.go b/libnetwork/drivers/overlay/joinleave.go index 9e513feaaa..a88f25e047 100644 --- a/libnetwork/drivers/overlay/joinleave.go +++ b/libnetwork/drivers/overlay/joinleave.go @@ -27,6 +27,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, return fmt.Errorf("could not find endpoint with id %s", eid) } + if n.secure && len(d.keys) == 0 { + return fmt.Errorf("cannot join secure network: encryption keys not present") + } + s := n.getSubnetforIP(ep.addr) if s == nil { return fmt.Errorf("could not find subnet for endpoint %s", eid) @@ -106,6 +110,10 @@ func (d *driver) Join(nid, eid string, sboxKey string, jinfo driverapi.JoinInfo, d.peerDbAdd(nid, eid, ep.addr.IP, ep.addr.Mask, ep.mac, net.ParseIP(d.bindAddress), true) + if err := d.checkEncryption(nid, nil, n.vxlanID(s), true, true); err != nil { + log.Warn(err) + } + buf, err := proto.Marshal(&PeerRecord{ EndpointIP: ep.addr.String(), EndpointMAC: ep.mac.String(), @@ -197,5 +205,9 @@ func (d *driver) Leave(nid, eid string) error { n.leaveSandbox() + if err := d.checkEncryption(nid, nil, 0, true, false); err != nil { + log.Warn(err) + } + return nil } diff --git a/libnetwork/drivers/overlay/ov_network.go b/libnetwork/drivers/overlay/ov_network.go index d1cf940491..bbaa4fed21 100644 --- a/libnetwork/drivers/overlay/ov_network.go +++ b/libnetwork/drivers/overlay/ov_network.go @@ -61,6 +61,7 @@ type network struct { initEpoch int initErr error subnets []*subnet + secure bool sync.Mutex } @@ -109,6 +110,9 @@ func (d *driver) CreateNetwork(id string, option map[string]interface{}, nInfo d vnis = append(vnis, uint32(vni)) } } + if _, ok := optMap["secure"]; ok { + n.secure = true + } } // If we are getting vnis from libnetwork, either we get for @@ -162,7 +166,18 @@ func (d *driver) DeleteNetwork(nid string) error { d.deleteNetwork(nid) - return n.releaseVxlanID() + vnis, err := n.releaseVxlanID() + if err != nil { + return err + } + + if n.secure { + for _, vni := range vnis { + programMangle(vni, false) + } + } + + return nil } func (d *driver) ProgramExternalConnectivity(nid, eid string, options map[string]interface{}) error { @@ -618,6 +633,8 @@ func (n *network) KeyPrefix() []string { } func (n *network) Value() []byte { + m := map[string]interface{}{} + netJSON := []*subnetJSON{} for _, s := range n.subnets { @@ -630,10 +647,17 @@ func (n *network) Value() []byte { } b, err := json.Marshal(netJSON) - if err != nil { return []byte{} } + + m["secure"] = n.secure + m["subnets"] = netJSON + b, err = json.Marshal(m) + if err != nil { + return []byte{} + } + return b } @@ -655,18 +679,38 @@ func (n *network) Skip() bool { } func (n *network) SetValue(value []byte) error { - var newNet bool - netJSON := []*subnetJSON{} + var ( + m map[string]interface{} + newNet bool + isMap = true + netJSON = []*subnetJSON{} + ) - err := json.Unmarshal(value, &netJSON) - if err != nil { - return err + if err := json.Unmarshal(value, &m); err != nil { + err := json.Unmarshal(value, &netJSON) + if err != nil { + return err + } + isMap = false } if len(n.subnets) == 0 { newNet = true } + if isMap { + if val, ok := m["secure"]; ok { + n.secure = val.(bool) + } + bytes, err := json.Marshal(m["subnets"]) + if err != nil { + return err + } + if err := json.Unmarshal(bytes, &netJSON); err != nil { + return err + } + } + for _, sj := range netJSON { subnetIPstr := sj.SubnetIP gwIPstr := sj.GwIP @@ -705,9 +749,9 @@ func (n *network) writeToStore() error { return n.driver.store.PutObjectAtomic(n) } -func (n *network) releaseVxlanID() error { +func (n *network) releaseVxlanID() ([]uint32, error) { if len(n.subnets) == 0 { - return nil + return nil, nil } if n.driver.store != nil { @@ -715,22 +759,24 @@ func (n *network) releaseVxlanID() error { if err == datastore.ErrKeyModified || err == datastore.ErrKeyNotFound { // In both the above cases we can safely assume that the key has been removed by some other // instance and so simply get out of here - return nil + return nil, nil } - return fmt.Errorf("failed to delete network to vxlan id map: %v", err) + return nil, fmt.Errorf("failed to delete network to vxlan id map: %v", err) } } - + var vnis []uint32 for _, s := range n.subnets { if n.driver.vxlanIdm != nil { - n.driver.vxlanIdm.Release(uint64(n.vxlanID(s))) + vni := n.vxlanID(s) + vnis = append(vnis, vni) + n.driver.vxlanIdm.Release(uint64(vni)) } n.setVxlanID(s, 0) } - return nil + return vnis, nil } func (n *network) obtainVxlanID(s *subnet) error { diff --git a/libnetwork/drivers/overlay/overlay.go b/libnetwork/drivers/overlay/overlay.go index cfdebb5072..f8766dd50f 100644 --- a/libnetwork/drivers/overlay/overlay.go +++ b/libnetwork/drivers/overlay/overlay.go @@ -37,12 +37,14 @@ type driver struct { neighIP string config map[string]interface{} peerDb peerNetworkMap + secMap *encrMap serfInstance *serf.Serf networks networkTable store datastore.DataStore vxlanIdm *idm.Idm once sync.Once joinOnce sync.Once + keys []*key sync.Mutex } @@ -51,12 +53,12 @@ func Init(dc driverapi.DriverCallback, config map[string]interface{}) error { c := driverapi.Capability{ DataScope: datastore.GlobalScope, } - d := &driver{ networks: networkTable{}, peerDb: peerNetworkMap{ mp: map[string]*peerMap{}, }, + secMap: &encrMap{nodes: map[string][]*spi{}}, config: config, } @@ -209,6 +211,7 @@ func (d *driver) pushLocalEndpointEvent(action, nid, eid string) { // DiscoverNew is a notification for a new discovery event, such as a new node joining a cluster func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{}) error { + var err error switch dType { case discoverapi.NodeDiscovery: nodeData, ok := data.(discoverapi.NodeDiscoveryData) @@ -217,7 +220,6 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{}) } d.nodeJoin(nodeData.Address, nodeData.Self) case discoverapi.DatastoreConfig: - var err error if d.store != nil { return types.ForbiddenErrorf("cannot accept datastore configuration: Overlay driver has a datastore configured already") } @@ -229,6 +231,39 @@ func (d *driver) DiscoverNew(dType discoverapi.DiscoveryType, data interface{}) if err != nil { return types.InternalErrorf("failed to initialize data store: %v", err) } + case discoverapi.EncryptionKeysConfig: + encrData, ok := data.(discoverapi.DriverEncryptionConfig) + if !ok { + return fmt.Errorf("invalid encryption key notification data") + } + keys := make([]*key, 0, len(encrData.Keys)) + for i := 0; i < len(encrData.Keys); i++ { + k, err := parseEncryptionKey(encrData.Keys[i], encrData.Tags[i]) + if err != nil { + return err + } + keys = append(keys, k) + } + d.setKeys(keys) + case discoverapi.EncryptionKeysUpdate: + var newKey, delKey, priKey *key + encrData, ok := data.(discoverapi.DriverEncryptionUpdate) + if !ok { + return fmt.Errorf("invalid encryption key notification data") + } + newKey, err = parseEncryptionKey(encrData.Key, encrData.Tag) + if err != nil { + return err + } + priKey, err = parseEncryptionKey(encrData.Primary, encrData.PrimaryTag) + if err != nil { + return err + } + delKey, err = parseEncryptionKey(encrData.Prune, encrData.PruneTag) + if err != nil { + return err + } + d.updateKeys(newKey, priKey, delKey) default: } return nil diff --git a/libnetwork/drivers/overlay/peerdb.go b/libnetwork/drivers/overlay/peerdb.go index 3676136434..2c1112fc1d 100644 --- a/libnetwork/drivers/overlay/peerdb.go +++ b/libnetwork/drivers/overlay/peerdb.go @@ -5,6 +5,8 @@ import ( "net" "sync" "syscall" + + log "github.com/Sirupsen/logrus" ) const ovPeerTable = "overlay_peer_table" @@ -88,7 +90,7 @@ func (d *driver) peerDbNetworkWalk(nid string, f func(*peerKey, *peerEntry) bool for pKeyStr, pEntry := range pMap.mp { var pKey peerKey if _, err := fmt.Sscan(pKeyStr, &pKey); err != nil { - fmt.Printf("peer key scan failed: %v", err) + log.Warnf("Peer key scan on network %s failed: %v", nid, err) } if f(&pKey, &pEntry) { @@ -273,6 +275,10 @@ func (d *driver) peerAdd(nid, eid string, peerIP net.IP, peerIPMask net.IPMask, return fmt.Errorf("subnet sandbox join failed for %q: %v", s.subnetIP.String(), err) } + if err := d.checkEncryption(nid, vtep, n.vxlanID(s), false, true); err != nil { + log.Warn(err) + } + // Add neighbor entry for the peer IP if err := sbox.AddNeighbor(peerIP, peerMac, sbox.NeighborOptions().LinkName(s.vxlanName)); err != nil { return fmt.Errorf("could not add neigbor entry into the sandbox: %v", err) @@ -318,6 +324,10 @@ func (d *driver) peerDelete(nid, eid string, peerIP net.IP, peerIPMask net.IPMas return fmt.Errorf("could not delete neigbor entry into the sandbox: %v", err) } + if err := d.checkEncryption(nid, vtep, 0, false, false); err != nil { + log.Warn(err) + } + return nil }