From d96e94897e82240e04bc4c69097e68e6ac83ed14 Mon Sep 17 00:00:00 2001 From: Jana Radhakrishnan Date: Fri, 15 May 2015 21:01:53 +0000 Subject: [PATCH] Fix DNS entry update issue When an update is done to the container resolv.conf file and it was inheriting host entries, then we should not re-read the host entries when the container leaves and re-joins the endpoint. Signed-off-by: Jana Radhakrishnan --- libnetwork/endpoint.go | 80 +++++++++++++++++++++++++++++++++-- libnetwork/libnetwork_test.go | 60 +++++++++++++++++++++++--- 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/libnetwork/endpoint.go b/libnetwork/endpoint.go index 5630bed231..702193da24 100644 --- a/libnetwork/endpoint.go +++ b/libnetwork/endpoint.go @@ -1,12 +1,15 @@ package libnetwork import ( + "bytes" "io/ioutil" "os" + "path" "path/filepath" "sync" "github.com/Sirupsen/logrus" + "github.com/docker/docker/pkg/ioutils" "github.com/docker/libnetwork/driverapi" "github.com/docker/libnetwork/netutils" "github.com/docker/libnetwork/pkg/etchosts" @@ -513,7 +516,7 @@ func (ep *endpoint) updateParentHosts() error { return nil } -func (ep *endpoint) setupDNS() error { +func (ep *endpoint) updateDNS(resolvConf []byte) error { ep.Lock() container := ep.container network := ep.network @@ -523,6 +526,77 @@ func (ep *endpoint) setupDNS() error { return ErrNoContainer } + hashFile := container.config.resolvConfPath + ".hash" + oldHash, err := ioutil.ReadFile(hashFile) + if err != nil { + if !os.IsNotExist(err) { + return err + } + + oldHash = []byte{} + } + + resolvBytes, err := ioutil.ReadFile(container.config.resolvConfPath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + } + + curHash, err := ioutils.HashData(bytes.NewReader(resolvBytes)) + if err != nil { + return err + } + + if string(oldHash) != "" && curHash != string(oldHash) { + // Seems the user has changed the container resolv.conf since the last time + // we checked so return without doing anything. + return nil + } + + // replace any localhost/127.* and remove IPv6 nameservers if IPv6 disabled. + resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, network.enableIPv6) + + newHash, err := ioutils.HashData(bytes.NewReader(resolvConf)) + if err != nil { + return err + } + + // for atomic updates to these files, use temporary files with os.Rename: + dir := path.Dir(container.config.resolvConfPath) + tmpHashFile, err := ioutil.TempFile(dir, "hash") + if err != nil { + return err + } + tmpResolvFile, err := ioutil.TempFile(dir, "resolv") + if err != nil { + return err + } + + // write the updates to the temp files + if err = ioutil.WriteFile(tmpHashFile.Name(), []byte(newHash), 0644); err != nil { + return err + } + if err = ioutil.WriteFile(tmpResolvFile.Name(), resolvConf, 0644); err != nil { + return err + } + + // rename the temp files for atomic replace + if err = os.Rename(tmpHashFile.Name(), hashFile); err != nil { + return err + } + return os.Rename(tmpResolvFile.Name(), container.config.resolvConfPath) +} + +func (ep *endpoint) setupDNS() error { + ep.Lock() + container := ep.container + ep.Unlock() + + if container == nil { + return ErrNoContainer + } + if container.config.resolvConfPath == "" { container.config.resolvConfPath = defaultPrefix + "/" + container.id + "/resolv.conf" } @@ -556,9 +630,7 @@ func (ep *endpoint) setupDNS() error { return resolvconf.Build(container.config.resolvConfPath, dnsList, dnsSearchList) } - // replace any localhost/127.* but always discard IPv6 entries for now. - resolvConf, _ = resolvconf.FilterResolvDNS(resolvConf, network.enableIPv6) - return ioutil.WriteFile(ep.container.config.resolvConfPath, resolvConf, 0644) + return ep.updateDNS(resolvConf) } // EndpointOptionGeneric function returns an option setter for a Generic option defined diff --git a/libnetwork/libnetwork_test.go b/libnetwork/libnetwork_test.go index 2e443ee28f..9cc1def240 100644 --- a/libnetwork/libnetwork_test.go +++ b/libnetwork/libnetwork_test.go @@ -1017,13 +1017,17 @@ func TestEnableIPv6(t *testing.T) { } } -func TestNoEnableIPv6(t *testing.T) { +func TestResolvConf(t *testing.T) { if !netutils.IsRunningInContainer() { defer netutils.SetupTestNetNS(t)() } - tmpResolvConf := []byte("search pommesfrites.fr\nnameserver 12.34.56.78\nnameserver 2001:4860:4860::8888") - expectedResolvConf := []byte("search pommesfrites.fr\nnameserver 12.34.56.78\n") + tmpResolvConf1 := []byte("search pommesfrites.fr\nnameserver 12.34.56.78\nnameserver 2001:4860:4860::8888") + expectedResolvConf1 := []byte("search pommesfrites.fr\nnameserver 12.34.56.78\n") + tmpResolvConf2 := []byte("search pommesfrites.fr\nnameserver 112.34.56.78\nnameserver 2001:4860:4860::8888") + expectedResolvConf2 := []byte("search pommesfrites.fr\nnameserver 112.34.56.78\n") + tmpResolvConf3 := []byte("search pommesfrites.fr\nnameserver 113.34.56.78\n") + //take a copy of resolv.conf for restoring after test completes resolvConfSystem, err := ioutil.ReadFile("/etc/resolv.conf") if err != nil { @@ -1046,7 +1050,7 @@ func TestNoEnableIPv6(t *testing.T) { t.Fatal(err) } - if err := ioutil.WriteFile("/etc/resolv.conf", tmpResolvConf, 0644); err != nil { + if err := ioutil.WriteFile("/etc/resolv.conf", tmpResolvConf1, 0644); err != nil { t.Fatal(err) } @@ -1069,13 +1073,57 @@ func TestNoEnableIPv6(t *testing.T) { t.Fatal(err) } - if !bytes.Equal(content, expectedResolvConf) { - t.Fatalf("Expected %s, Got %s", string(expectedResolvConf), string(content)) + if !bytes.Equal(content, expectedResolvConf1) { + t.Fatalf("Expected %s, Got %s", string(expectedResolvConf1), string(content)) } + err = ep1.Leave(containerID) if err != nil { t.Fatal(err) } + + if err := ioutil.WriteFile("/etc/resolv.conf", tmpResolvConf2, 0644); err != nil { + t.Fatal(err) + } + + _, err = ep1.Join(containerID, + libnetwork.JoinOptionResolvConfPath(resolvConfPath)) + if err != nil { + t.Fatal(err) + } + + content, err = ioutil.ReadFile(resolvConfPath) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(content, expectedResolvConf2) { + t.Fatalf("Expected %s, Got %s", string(expectedResolvConf2), string(content)) + } + + if err := ioutil.WriteFile(resolvConfPath, tmpResolvConf3, 0644); err != nil { + t.Fatal(err) + } + + err = ep1.Leave(containerID) + if err != nil { + t.Fatal(err) + } + + _, err = ep1.Join(containerID, + libnetwork.JoinOptionResolvConfPath(resolvConfPath)) + if err != nil { + t.Fatal(err) + } + + content, err = ioutil.ReadFile(resolvConfPath) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(content, tmpResolvConf3) { + t.Fatalf("Expected %s, Got %s", string(tmpResolvConf3), string(content)) + } } var (