diff --git a/libnetwork/etchosts/etchosts.go b/libnetwork/etchosts/etchosts.go index 88e6b63e70..9095b483b4 100644 --- a/libnetwork/etchosts/etchosts.go +++ b/libnetwork/etchosts/etchosts.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/ioutil" + "os" "regexp" ) @@ -65,6 +66,45 @@ func Build(path, IP, hostname, domainname string, extraContent []Record) error { return ioutil.WriteFile(path, content.Bytes(), 0644) } +// Add adds an arbitrary number of Records to an already existing /etc/hosts file +func Add(path string, recs []Record) error { + f, err := os.Open(path) + if err != nil { + return err + } + + content := bytes.NewBuffer(nil) + + _, err = content.ReadFrom(f) + if err != nil { + return err + } + + for _, r := range recs { + if _, err := r.WriteTo(content); err != nil { + return err + } + } + + return ioutil.WriteFile(path, content.Bytes(), 0644) +} + +// Delete deletes an arbitrary number of Records already existing in /etc/hosts file +func Delete(path string, recs []Record) error { + old, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + regexpStr := fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(recs[0].Hosts)) + for _, r := range recs[1:] { + regexpStr = regexpStr + "|" + fmt.Sprintf("\\S*\\t%s\\n", regexp.QuoteMeta(r.Hosts)) + } + + var re = regexp.MustCompile(regexpStr) + return ioutil.WriteFile(path, re.ReplaceAll(old, []byte("")), 0644) +} + // Update all IP addresses where hostname matches. // path is path to host file // IP is new IP address diff --git a/libnetwork/etchosts/etchosts_test.go b/libnetwork/etchosts/etchosts_test.go index 8c8b87c016..ce17d57455 100644 --- a/libnetwork/etchosts/etchosts_test.go +++ b/libnetwork/etchosts/etchosts_test.go @@ -134,3 +134,82 @@ func TestUpdate(t *testing.T) { t.Fatalf("Expected to find '%s' got '%s'", expected, content) } } + +func TestAdd(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + err = Build(file.Name(), "", "", "", nil) + if err != nil { + t.Fatal(err) + } + + if err := Add(file.Name(), []Record{ + Record{ + Hosts: "testhostname", + IP: "2.2.2.2", + }, + }); err != nil { + t.Fatal(err) + } + + content, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatal(err) + } + + if expected := "2.2.2.2\ttesthostname\n"; !bytes.Contains(content, []byte(expected)) { + t.Fatalf("Expected to find '%s' got '%s'", expected, content) + } +} + +func TestDelete(t *testing.T) { + file, err := ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(file.Name()) + + err = Build(file.Name(), "", "", "", nil) + if err != nil { + t.Fatal(err) + } + + if err := Add(file.Name(), []Record{ + Record{ + Hosts: "testhostname1", + IP: "1.1.1.1", + }, + Record{ + Hosts: "testhostname2", + IP: "2.2.2.2", + }, + }); err != nil { + t.Fatal(err) + } + + if err := Delete(file.Name(), []Record{ + Record{ + Hosts: "testhostname1", + IP: "1.1.1.1", + }, + }); err != nil { + t.Fatal(err) + } + + content, err := ioutil.ReadFile(file.Name()) + if err != nil { + t.Fatal(err) + } + + if expected := "2.2.2.2\ttesthostname2\n"; !bytes.Contains(content, []byte(expected)) { + t.Fatalf("Expected to find '%s' got '%s'", expected, content) + } + + if expected := "1.1.1.1\ttesthostname1\n"; bytes.Contains(content, []byte(expected)) { + t.Fatalf("Did not expect to find '%s' got '%s'", expected, content) + } +}