diff --git a/graph.go b/graph.go index 606a6833ee..c54725fdb4 100644 --- a/graph.go +++ b/graph.go @@ -323,9 +323,9 @@ func (graph *Graph) ByParent() (map[string][]*Image, error) { return } if children, exists := byParent[parent.ID]; exists { - byParent[parent.ID] = []*Image{image} - } else { byParent[parent.ID] = append(children, image) + } else { + byParent[parent.ID] = []*Image{image} } }) return byParent, err diff --git a/graph_test.go b/graph_test.go index 2898fccf99..32fb0ef441 100644 --- a/graph_test.go +++ b/graph_test.go @@ -234,6 +234,45 @@ func TestDelete(t *testing.T) { assertNImages(graph, t, 1) } +func TestByParent(t *testing.T) { + archive1, _ := fakeTar() + archive2, _ := fakeTar() + archive3, _ := fakeTar() + + graph := tempGraph(t) + defer os.RemoveAll(graph.Root) + parentImage := &Image{ + ID: GenerateID(), + Comment: "parent", + Created: time.Now(), + Parent: "", + } + childImage1 := &Image{ + ID: GenerateID(), + Comment: "child1", + Created: time.Now(), + Parent: parentImage.ID, + } + childImage2 := &Image{ + ID: GenerateID(), + Comment: "child2", + Created: time.Now(), + Parent: parentImage.ID, + } + _ = graph.Register(nil, archive1, parentImage) + _ = graph.Register(nil, archive2, childImage1) + _ = graph.Register(nil, archive3, childImage2) + + byParent, err := graph.ByParent() + if err != nil { + t.Fatal(err) + } + numChildren := len(byParent[parentImage.ID]) + if numChildren != 2 { + t.Fatalf("Expected 2 children, found %d", numChildren) + } +} + func assertNImages(graph *Graph, t *testing.T, n int) { if images, err := graph.All(); err != nil { t.Fatal(err)