diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index a29e850293..7af1850cf6 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -49,8 +49,14 @@ const ( defaultRespSize = 512 maxConcurrent = 50 logInterval = 2 * time.Second + maxDNSID = 65536 ) +type clientConn struct { + dnsID uint16 + respWriter dns.ResponseWriter +} + type extDNSEntry struct { ipStr string extConn net.Conn @@ -69,6 +75,7 @@ type resolver struct { count int32 tStamp time.Time queryLock sync.Mutex + client map[uint16]clientConn } func init() { @@ -78,8 +85,9 @@ func init() { // NewResolver creates a new instance of the Resolver func NewResolver(sb *sandbox) Resolver { return &resolver{ - sb: sb, - err: fmt.Errorf("setup not done yet"), + sb: sb, + err: fmt.Errorf("setup not done yet"), + client: make(map[uint16]clientConn), } } @@ -375,7 +383,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { extConn.SetDeadline(time.Now().Add(extIOTimeout)) co := &dns.Conn{Conn: extConn} - if r.concurrentQueryInc() == false { + // forwardQueryStart stores required context to mux multiple client queries over + // one connection; and limits the number of outstanding concurrent queries. + if r.forwardQueryStart(w, query) == false { old := r.tStamp r.tStamp = time.Now() if r.tStamp.Sub(old) > logInterval { @@ -391,18 +401,25 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { }() err = co.WriteMsg(query) if err != nil { - r.concurrentQueryDec() + r.forwardQueryEnd(w, query) log.Debugf("Send to DNS server failed, %s", err) continue } resp, err = co.ReadMsg() - r.concurrentQueryDec() if err != nil { + r.forwardQueryEnd(w, query) log.Debugf("Read from DNS server failed, %s", err) continue } + // Retrieves the context for the forwarded query and returns the client connection + // to send the reply to + w = r.forwardQueryEnd(w, resp) + if w == nil { + continue + } + resp.Compress = true break } @@ -418,22 +435,71 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { } } -func (r *resolver) concurrentQueryInc() bool { +func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool { + proto := w.LocalAddr().Network() + dnsID := uint16(rand.Intn(maxDNSID)) + + cc := clientConn{ + dnsID: msg.Id, + respWriter: w, + } + r.queryLock.Lock() defer r.queryLock.Unlock() + if r.count == maxConcurrent { return false } r.count++ + + switch proto { + case "tcp": + break + case "udp": + for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) { + _, ok = r.client[dnsID] + } + log.Debugf("client dns id %v, changed id %v", msg.Id, dnsID) + r.client[dnsID] = cc + msg.Id = dnsID + default: + log.Errorf("Invalid protocol..") + return false + } + return true } -func (r *resolver) concurrentQueryDec() bool { +func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter { + var ( + cc clientConn + ok bool + ) + proto := w.LocalAddr().Network() + r.queryLock.Lock() defer r.queryLock.Unlock() + if r.count == 0 { - return false + log.Errorf("Invalid concurrent query count") + } else { + r.count-- } - r.count-- - return true + + switch proto { + case "tcp": + break + case "udp": + if cc, ok = r.client[msg.Id]; ok == false { + log.Debugf("Can't retrieve client context for dns id %v", msg.Id) + return nil + } + delete(r.client, msg.Id) + msg.Id = cc.dnsID + w = cc.respWriter + default: + log.Errorf("Invalid protocol") + return nil + } + return w }