diff --git a/libnetwork/resolver.go b/libnetwork/resolver.go index becee88345..b94903e835 100644 --- a/libnetwork/resolver.go +++ b/libnetwork/resolver.go @@ -292,6 +292,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { extConn net.Conn resp *dns.Msg err error + writer dns.ResponseWriter ) if query == nil || len(query.Question) == 0 { @@ -329,7 +330,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { if resp.Len() > maxSize { truncateResp(resp, maxSize, proto == "tcp") } + writer = w } else { + queryID := query.Id for i := 0; i < maxExtDNS; i++ { extDNS := &r.extDNSList[i] if extDNS.ipStr == "" { @@ -375,7 +378,7 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { // forwardQueryStart stores required context to mux multiple client queries over // one connection; and limits the number of outstanding concurrent queries. - if r.forwardQueryStart(w, query) == false { + if r.forwardQueryStart(w, query, queryID) == false { old := r.tStamp r.tStamp = time.Now() if r.tStamp.Sub(old) > logInterval { @@ -405,32 +408,33 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) { // Retrieves the context for the forwarded query and returns the client connection // to send the reply to - w = r.forwardQueryEnd(w, resp) - if w == nil { + writer = r.forwardQueryEnd(w, resp) + if writer == nil { continue } resp.Compress = true break } - - if resp == nil || w == nil { + if resp == nil || writer == nil { return } } - err = w.WriteMsg(resp) - if err != nil { + if writer == nil { + return + } + if err = writer.WriteMsg(resp); err != nil { log.Errorf("error writing resolver resp, %s", err) } } -func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool { +func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg, queryID uint16) bool { proto := w.LocalAddr().Network() dnsID := uint16(rand.Intn(maxDNSID)) cc := clientConn{ - dnsID: msg.Id, + dnsID: queryID, respWriter: w, }