1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00
moby--moby/libnetwork/resolver.go
Jana Radhakrishnan d81a91ebad Resolver sockets not flushed on default gw change
Currently, when the default gw changes because of
other network connections happening in the container,
the resolver sockets are not flushed. This results
in a subsequent DNS failure for external queries.

A sequence of connecting the container to an overlay
network and subsequently to a bridge network without
disconnecting from any network will result in this
behaviour. This was revealed by one of the libnetwork
IT tests.

This is now fixed as part of the commit by flushing
the external query sockets when a default gw change
is detected.

Signed-off-by: Jana Radhakrishnan <mrjana@docker.com>
2016-04-10 10:40:06 -07:00

508 lines
12 KiB
Go

package libnetwork
import (
"fmt"
"math/rand"
"net"
"strings"
"sync"
"time"
log "github.com/Sirupsen/logrus"
"github.com/docker/libnetwork/iptables"
"github.com/docker/libnetwork/netutils"
"github.com/miekg/dns"
)
// Resolver is the contract of the DNS server that Docker embeds in a
// container. The server listens on the container's loopback interface and
// answers the DNS queries issued there.
type Resolver interface {
	// Start launches the name server for the container.
	Start() error

	// Stop shuts the name server for the container down. A stopped
	// resolver may be used again once its SetupFunc has been re-run.
	Stop()

	// SetupFunc returns the setup routine, which has to be executed in the
	// container's network namespace.
	SetupFunc() func()

	// NameServer returns the IP of the DNS resolver for the containers.
	NameServer() string

	// SetExtServers configures the external name servers to which the
	// resolver forwards queries.
	SetExtServers([]string)

	// FlushExtServers drops the cached UDP connections to the external
	// name servers.
	FlushExtServers()

	// ResolverOptions returns the resolv.conf options that should be set.
	ResolverOptions() []string
}
// Tunables and fixed values used by the embedded resolver.
const (
	// Address and port on which the container reaches the resolver.
	resolverIP = "127.0.0.11"
	dnsPort    = "53"

	// Domain suffixes that identify reverse (PTR) lookups.
	ptrIPv4domain = ".in-addr.arpa."
	ptrIPv6domain = ".ip6.arpa."

	// TTL put on every locally generated answer record.
	respTTL = 600

	// Forwarding to external servers.
	maxExtDNS    = 3               // max number of external servers to try
	extIOTimeout = 4 * time.Second // dial timeout and per-exchange I/O deadline

	// Smallest reply size limit applied to UDP responses.
	defaultRespSize = 512

	// Upper bound on simultaneously outstanding forwarded queries, and the
	// minimum spacing between log messages reporting that it was hit.
	maxConcurrent = 50
	logInterval   = 2 * time.Second

	// Exclusive upper bound for randomly generated DNS message IDs.
	maxDNSID = 65536
)
// clientConn is the context remembered for a query that has been forwarded
// to an external server, so that the reply can be routed back.
type clientConn struct {
	// dnsID is the message ID the client used in its original query.
	dnsID uint16
	// respWriter is the writer on which the client expects the reply.
	respWriter dns.ResponseWriter
}
// extDNSEntry describes one external name server used for forwarding.
type extDNSEntry struct {
	// ipStr is the server's IP address; an empty string marks an unused slot.
	ipStr string
	// extConn is the cached connection to the server, nil until dialed.
	extConn net.Conn
	// extOnce makes sure the cached connection is dialed only once.
	extOnce sync.Once
}
// resolver is the concrete implementation of the Resolver interface.
type resolver struct {
	// sb is the sandbox used for name/IP lookups and for running code in
	// the container's network namespace.
	sb *sandbox
	// extDNSList holds the external servers queries are forwarded to.
	extDNSList [maxExtDNS]extDNSEntry

	// UDP server and the socket it serves on.
	server *dns.Server
	conn   *net.UDPConn
	// TCP server and the listener it serves on.
	tcpServer *dns.Server
	tcpListen *net.TCPListener

	// err is non-nil while the resolver has not been set up successfully.
	err error

	// count is the number of forwarded queries currently outstanding.
	count int32
	// tStamp is the last time a query was refused for exceeding the
	// concurrency limit; used to rate-limit the corresponding log message.
	tStamp time.Time
	// queryLock is held while count and client are updated.
	queryLock sync.Mutex
	// client maps the ID used towards the external server to the context
	// of the client query it was generated for.
	client map[uint16]clientConn
}
// init seeds the package-level math/rand source with the current Unix time
// in seconds.
func init() {
	seed := time.Now().Unix()
	rand.Seed(seed)
}
// NewResolver returns a fresh Resolver bound to the sandbox sb. The resolver
// starts out in the "setup not done yet" error state, so Start fails until
// the function returned by SetupFunc has completed successfully.
func NewResolver(sb *sandbox) Resolver {
	res := new(resolver)
	res.sb = sb
	res.err = fmt.Errorf("setup not done yet")
	res.client = make(map[uint16]clientConn)
	return res
}
// SetupFunc returns the function that prepares the resolver for serving.
// The returned function:
//
//   - opens a UDP socket and a TCP listener on resolverIP, each on a port
//     chosen by the kernel (no port is specified);
//   - installs NAT rules that redirect traffic sent to resolverIP:dnsPort
//     to those sockets and rewrite the source port of the replies back to
//     dnsPort.
//
// The outcome is stored in r.err: nil on success, otherwise the error that
// made setup fail. On failure any socket opened so far is closed again so
// that a failed setup does not leak file descriptors.
func (r *resolver) SetupFunc() func() {
	return func() {
		var err error

		// releaseSockets closes and forgets the listening sockets opened
		// by this setup run.
		releaseSockets := func() {
			if r.conn != nil {
				r.conn.Close()
				r.conn = nil
			}
			if r.tcpListen != nil {
				r.tcpListen.Close()
				r.tcpListen = nil
			}
		}

		// DNS operates primarily on UDP
		addr := &net.UDPAddr{
			IP: net.ParseIP(resolverIP),
		}

		r.conn, err = net.ListenUDP("udp", addr)
		if err != nil {
			r.err = fmt.Errorf("error in opening name server socket %v", err)
			return
		}
		laddr := r.conn.LocalAddr()
		// The error is ignored on purpose: the string form of a bound
		// socket address is always "host:port".
		_, ipPort, _ := net.SplitHostPort(laddr.String())

		// Listen on a TCP as well
		tcpaddr := &net.TCPAddr{
			IP: net.ParseIP(resolverIP),
		}

		r.tcpListen, err = net.ListenTCP("tcp", tcpaddr)
		if err != nil {
			r.err = fmt.Errorf("error in opening name TCP server socket %v", err)
			// Do not leave the UDP socket behind.
			releaseSockets()
			return
		}
		ltcpaddr := r.tcpListen.Addr()
		// See above: cannot fail for a bound listener address.
		_, tcpPort, _ := net.SplitHostPort(ltcpaddr.String())

		rules := [][]string{
			{"-t", "nat", "-A", "OUTPUT", "-d", resolverIP, "-p", "udp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", laddr.String()},
			{"-t", "nat", "-A", "POSTROUTING", "-s", resolverIP, "-p", "udp", "--sport", ipPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
			{"-t", "nat", "-A", "OUTPUT", "-d", resolverIP, "-p", "tcp", "--dport", dnsPort, "-j", "DNAT", "--to-destination", ltcpaddr.String()},
			{"-t", "nat", "-A", "POSTROUTING", "-s", resolverIP, "-p", "tcp", "--sport", tcpPort, "-j", "SNAT", "--to-source", ":" + dnsPort},
		}

		for _, rule := range rules {
			r.err = iptables.RawCombinedOutputNative(rule...)
			if r.err != nil {
				// Without the NAT rules the sockets are unreachable.
				releaseSockets()
				return
			}
		}
		r.err = nil
	}
}
// Start begins serving DNS on the UDP socket and the TCP listener prepared
// by the setup function, each in its own goroutine. It returns the pending
// setup error, without starting anything, if setup has not succeeded.
func (r *resolver) Start() error {
	// make sure the resolver has been setup before starting
	if setupErr := r.err; setupErr != nil {
		return setupErr
	}

	udpSrv := &dns.Server{PacketConn: r.conn, Handler: r}
	r.server = udpSrv
	go udpSrv.ActivateAndServe()

	tcpSrv := &dns.Server{Listener: r.tcpListen, Handler: r}
	r.tcpServer = tcpSrv
	go tcpSrv.ActivateAndServe()

	return nil
}
// FlushExtServers closes every cached connection to an external name server
// and re-arms the per-server Once, so the next forwarded query dials a new
// connection. The configured server addresses are left untouched.
func (r *resolver) FlushExtServers() {
	for idx := range r.extDNSList {
		entry := &r.extDNSList[idx]
		if conn := entry.extConn; conn != nil {
			conn.Close()
		}
		entry.extConn = nil
		entry.extOnce = sync.Once{}
	}
}
// Stop flushes the external server connections, shuts down the UDP and TCP
// servers if they exist, and puts the resolver back into the
// "setup not done yet" state with its query bookkeeping reset.
func (r *resolver) Stop() {
	r.FlushExtServers()

	if udpSrv := r.server; udpSrv != nil {
		udpSrv.Shutdown()
	}
	if tcpSrv := r.tcpServer; tcpSrv != nil {
		tcpSrv.Shutdown()
	}

	r.conn, r.tcpServer = nil, nil
	r.err = fmt.Errorf("setup not done yet")

	// Reset the forwarded-query bookkeeping.
	r.tStamp = time.Time{}
	r.count = 0
	r.queryLock = sync.Mutex{}
}
// SetExtServers replaces the list of external name servers that queries are
// forwarded to. At most maxExtDNS addresses are used; any further entries
// in dns are ignored.
//
// Slots beyond the supplied list are cleared. Without that, calling
// SetExtServers with a shorter list than before would leave the old
// trailing servers in place and ServeDNS would keep forwarding to them.
func (r *resolver) SetExtServers(dns []string) {
	l := len(dns)
	if l > maxExtDNS {
		l = maxExtDNS
	}
	for i := 0; i < l; i++ {
		r.extDNSList[i].ipStr = dns[i]
	}
	// ServeDNS stops at the first empty address, so blanking the remaining
	// slots is enough to retire servers from an earlier configuration.
	for i := l; i < maxExtDNS; i++ {
		r.extDNSList[i].ipStr = ""
	}
}
// NameServer returns the fixed address (resolverIP) on which the embedded
// resolver is reachable.
func (r *resolver) NameServer() string {
	ip := resolverIP
	return ip
}
// ResolverOptions returns the resolv.conf options to set for the embedded
// resolver: the single option "ndots:0".
func (r *resolver) ResolverOptions() []string {
	opts := make([]string, 0, 1)
	opts = append(opts, "ndots:0")
	return opts
}
// setCommonFlags sets the header flags shared by every locally generated
// response; currently that is only "recursion available".
func setCommonFlags(msg *dns.Msg) {
	const recursionAvailable = true
	msg.RecursionAvailable = recursionAvailable
}
// shuffleAddr randomly permutes addr in place (walking from the last
// element down and swapping each with a randomly chosen element at or
// before it) and returns the same slice.
func shuffleAddr(addr []net.IP) []net.IP {
	for remaining := len(addr); remaining > 1; remaining-- {
		last := remaining - 1
		pick := rand.Intn(remaining)
		addr[last], addr[pick] = addr[pick], addr[last]
	}
	return addr
}
// createRespMsg builds an empty reply to query with the common response
// flags already set.
func createRespMsg(query *dns.Msg) *dns.Msg {
	reply := &dns.Msg{}
	reply.SetReply(query)
	setCommonFlags(reply)
	return reply
}
// handleIPQuery answers an A or AAAA question for name from the sandbox's
// records. ipType selects the address family (netutils.IPv4 yields A
// records, anything else AAAA records).
//
// It returns:
//   - a reply without answers when the name is known but the lookup
//     reported an IPv6 miss;
//   - (nil, nil) when the name is not known locally, so the caller can
//     forward the query;
//   - otherwise a reply with one record per address, in random order.
func (r *resolver) handleIPQuery(name string, query *dns.Msg, ipType int) (*dns.Msg, error) {
	ips, v6Miss := r.sb.ResolveName(name, ipType)
	if ips == nil {
		if !v6Miss {
			return nil, nil
		}
		// Send a reply without any Answer sections
		log.Debugf("Lookup name %s present without IPv6 address", name)
		return createRespMsg(query), nil
	}

	log.Debugf("Lookup for %s: IP %v", name, ips)

	reply := createRespMsg(query)
	if len(ips) > 1 {
		ips = shuffleAddr(ips)
	}

	wantV4 := ipType == netutils.IPv4
	for _, ip := range ips {
		hdr := dns.RR_Header{Name: name, Class: dns.ClassINET, Ttl: respTTL}
		if wantV4 {
			hdr.Rrtype = dns.TypeA
			reply.Answer = append(reply.Answer, &dns.A{Hdr: hdr, A: ip})
		} else {
			hdr.Rrtype = dns.TypeAAAA
			reply.Answer = append(reply.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip})
		}
	}
	return reply, nil
}
// handlePTRQuery answers a reverse lookup. ptr must end in the IPv4 or IPv6
// reverse domain, otherwise an error is returned. The part of ptr in front
// of the reverse domain is handed to the sandbox for resolution.
//
// It returns (nil, nil) when the sandbox knows no name for the address, so
// the caller can forward the query; otherwise a reply carrying a single PTR
// record with the fully qualified host name.
func (r *resolver) handlePTRQuery(ptr string, query *dns.Msg) (*dns.Msg, error) {
	var domain string
	switch {
	case strings.HasSuffix(ptr, ptrIPv4domain):
		domain = ptrIPv4domain
	case strings.HasSuffix(ptr, ptrIPv6domain):
		domain = ptrIPv6domain
	default:
		return nil, fmt.Errorf("invalid PTR query, %v", ptr)
	}

	revAddr := strings.Split(ptr, domain)[0]
	hostName := r.sb.ResolveIP(revAddr)
	if len(hostName) == 0 {
		return nil, nil
	}

	log.Debugf("Lookup for IP %s: name %s", revAddr, hostName)

	reply := new(dns.Msg)
	reply.SetReply(query)
	setCommonFlags(reply)

	record := &dns.PTR{
		Hdr: dns.RR_Header{Name: ptr, Rrtype: dns.TypePTR, Class: dns.ClassINET, Ttl: respTTL},
		Ptr: dns.Fqdn(hostName),
	}
	reply.Answer = append(reply.Answer, record)
	return reply, nil
}
// truncateResp shrinks resp until it fits into maxSize bytes by dropping
// answer records from the end. For non-TCP transports the Truncated flag is
// set on the response.
//
// Trimming stops once no answer records are left, even if the message is
// still larger than maxSize; only the Answer section is ever trimmed.
func truncateResp(resp *dns.Msg, maxSize int, isTCP bool) {
	if !isTCP {
		resp.Truncated = true
	}

	// trim the Answer RRs one by one till the whole message fits
	// within the reply size. The length guard prevents slicing an empty
	// Answer section (which would panic) when the remaining sections alone
	// exceed maxSize.
	for len(resp.Answer) > 0 && resp.Len() > maxSize {
		resp.Answer = resp.Answer[:len(resp.Answer)-1]
	}
}
// ServeDNS handles one DNS query received by the embedded resolver; the
// resolver registers itself as the Handler of its UDP and TCP servers.
//
// Only the first question of the query is looked at. A, AAAA and PTR
// questions are first answered from the sandbox's own records. A locally
// built answer is trimmed to the reply size allowed by the transport
// (dns.MaxMsgSize-1 for TCP; the EDNS0 UDP size, but at least
// defaultRespSize, for UDP). When there is no local answer the query is
// forwarded to the configured external servers, in order, until one of
// them produces a response, which is then relayed to the client.
//
// Queries that are nil, carry no question, fail local resolution with an
// error, or cannot be answered by any external server get no reply.
func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
	var (
		extConn net.Conn
		resp    *dns.Msg
		err     error
	)

	if query == nil || len(query.Question) == 0 {
		return
	}

	name := query.Question[0].Name
	switch query.Question[0].Qtype {
	case dns.TypeA:
		resp, err = r.handleIPQuery(name, query, netutils.IPv4)
	case dns.TypeAAAA:
		resp, err = r.handleIPQuery(name, query, netutils.IPv6)
	case dns.TypePTR:
		resp, err = r.handlePTRQuery(name, query)
	}

	if err != nil {
		log.Error(err)
		return
	}

	proto := w.LocalAddr().Network()
	maxSize := 0
	if proto == "tcp" {
		maxSize = dns.MaxMsgSize - 1
	} else if proto == "udp" {
		optRR := query.IsEdns0()
		if optRR != nil {
			maxSize = int(optRR.UDPSize())
		}
		if maxSize < defaultRespSize {
			maxSize = defaultRespSize
		}
	}

	if resp != nil {
		if resp.Len() > maxSize {
			truncateResp(resp, maxSize, proto == "tcp")
		}
	} else {
		for i := 0; i < maxExtDNS; i++ {
			extDNS := &r.extDNSList[i]
			// An empty address marks the end of the configured servers.
			if extDNS.ipStr == "" {
				break
			}
			// extConnect dials the external server; it is run through the
			// sandbox's execFunc.
			extConnect := func() {
				addr := fmt.Sprintf("%s:%d", extDNS.ipStr, 53)
				extConn, err = net.DialTimeout(proto, addr, extIOTimeout)
			}

			// For udp clients the connection is persisted and reused for
			// further queries. Accessing extDNS.extConn can race between
			// goroutines here, hence the connection setup is done in a
			// Once block and extConn is fetched again afterwards.
			extConn = extDNS.extConn
			if extConn == nil || proto == "tcp" {
				if proto == "udp" {
					extDNS.extOnce.Do(func() {
						r.sb.execFunc(extConnect)
						extDNS.extConn = extConn
					})
					extConn = extDNS.extConn
				} else {
					r.sb.execFunc(extConnect)
				}
				if err != nil {
					log.Debugf("Connect failed, %s", err)
					continue
				}
			}
			// If two goroutines are executing in parallel, one will block
			// on the Once.Do and, in case of an error connecting to the
			// external server, it will end up with a nil err but extConn
			// also being nil.
			if extConn == nil {
				continue
			}
			log.Debugf("Query %s[%d] from %s, forwarding to %s:%s", name, query.Question[0].Qtype,
				extConn.LocalAddr().String(), proto, extDNS.ipStr)

			// Timeout has to be set for every IO operation.
			extConn.SetDeadline(time.Now().Add(extIOTimeout))
			co := &dns.Conn{Conn: extConn}

			// TCP connections are dialed per query and never cached, so
			// each one must be closed when the handler returns. This is
			// registered before any path that can skip to the next server,
			// otherwise a refused or failed attempt leaks the connection.
			if proto == "tcp" {
				defer co.Close()
			}

			// forwardQueryStart stores required context to mux multiple client queries over
			// one connection; and limits the number of outstanding concurrent queries.
			if !r.forwardQueryStart(w, query) {
				old := r.tStamp
				r.tStamp = time.Now()
				if r.tStamp.Sub(old) > logInterval {
					log.Errorf("More than %v concurrent queries from %s", maxConcurrent, w.LocalAddr().String())
				}
				continue
			}

			err = co.WriteMsg(query)
			if err != nil {
				r.forwardQueryEnd(w, query)
				log.Debugf("Send to DNS server failed, %s", err)
				continue
			}

			resp, err = co.ReadMsg()
			if err != nil {
				r.forwardQueryEnd(w, query)
				log.Debugf("Read from DNS server failed, %s", err)
				continue
			}

			// Retrieves the context for the forwarded query and returns the client connection
			// to send the reply to
			client := r.forwardQueryEnd(w, resp)
			if client == nil {
				// The response cannot be attributed to any client, so
				// there is nothing to relay. Give up here instead of
				// carrying a nil writer into the next iteration, where it
				// would be dereferenced.
				return
			}
			w = client
			resp.Compress = true
			break
		}

		if resp == nil || w == nil {
			return
		}
	}

	err = w.WriteMsg(resp)
	if err != nil {
		log.Errorf("error writing resolver resp, %s", err)
	}
}
// forwardQueryStart registers a query that is about to be forwarded to an
// external server and reports whether forwarding may proceed.
//
// It returns false, without reserving anything, when the client transport
// is neither "tcp" nor "udp" or when maxConcurrent forwarded queries are
// already outstanding. Otherwise the outstanding-query count is incremented
// and true is returned; the caller must then balance the call with
// forwardQueryEnd.
//
// For UDP the message ID of msg is replaced by a randomly chosen ID that is
// not in use by any other outstanding query, and the client's original ID
// and response writer are remembered under that new ID so the reply can be
// routed back.
func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool {
	proto := w.LocalAddr().Network()
	// Validate the protocol before touching the counter; a rejected query
	// is never passed to forwardQueryEnd, so counting it would permanently
	// consume a concurrency slot.
	if proto != "tcp" && proto != "udp" {
		log.Errorf("Invalid protocol..")
		return false
	}

	r.queryLock.Lock()
	defer r.queryLock.Unlock()

	if r.count >= maxConcurrent {
		return false
	}
	r.count++

	if proto == "udp" {
		// Draw IDs until one is found that no outstanding query uses. The
		// chosen ID is the one that was checked: a new ID is drawn only
		// after a collision.
		dnsID := uint16(rand.Intn(maxDNSID))
		for {
			if _, inUse := r.client[dnsID]; !inUse {
				break
			}
			dnsID = uint16(rand.Intn(maxDNSID))
		}
		log.Debugf("client dns id %v, changed id %v", msg.Id, dnsID)
		r.client[dnsID] = clientConn{
			dnsID:      msg.Id,
			respWriter: w,
		}
		msg.Id = dnsID
	}
	return true
}
// forwardQueryEnd finishes the bookkeeping for a forwarded query and
// returns the writer on which the reply has to be sent.
//
// The outstanding-query count is decremented (an error is logged if it is
// already zero). For TCP the writer w is returned unchanged. For UDP the
// client context stored under msg.Id is looked up and removed, msg.Id is
// restored to the client's original ID and the client's writer is
// returned; nil is returned when no context exists for msg.Id. Any other
// protocol yields nil.
func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
	network := w.LocalAddr().Network()

	r.queryLock.Lock()
	defer r.queryLock.Unlock()

	if r.count != 0 {
		r.count--
	} else {
		log.Errorf("Invalid concurrent query count")
	}

	if network == "tcp" {
		return w
	}
	if network != "udp" {
		log.Errorf("Invalid protocol")
		return nil
	}

	saved, found := r.client[msg.Id]
	if !found {
		log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
		return nil
	}
	delete(r.client, msg.Id)
	msg.Id = saved.dnsID
	return saved.respWriter
}