diff --git a/server.go b/server.go index 646cb44877..69edabf07a 100644 --- a/server.go +++ b/server.go @@ -667,29 +667,57 @@ func (srv *Server) ImagePull(localName string, tag string, out io.Writer, sf *ut // Retrieve the all the images to be uploaded in the correct order // Note: we can't use a map as it is not ordered -func (srv *Server) getImageList(localRepo map[string]string) ([]*registry.ImgData, error) { - var imgList []*registry.ImgData +func (srv *Server) getImageList(localRepo map[string]string) ([][]*registry.ImgData, error) { + imgList := map[string]*registry.ImgData{} + depGraph := utils.NewDependencyGraph() - imageSet := make(map[string]struct{}) for tag, id := range localRepo { img, err := srv.runtime.graph.Get(id) if err != nil { return nil, err } - img.WalkHistory(func(img *Image) error { - if _, exists := imageSet[img.ID]; exists { + depGraph.NewNode(img.ID) + img.WalkHistory(func(current *Image) error { + imgList[current.ID] = ®istry.ImgData{ + ID: current.ID, + Tag: tag, + } + parent, err := current.GetParent() + if err != nil { + return err + } + if parent == nil { return nil } - imageSet[img.ID] = struct{}{} - - imgList = append([]*registry.ImgData{{ - ID: img.ID, - Tag: tag, - }}, imgList...) + depGraph.NewNode(parent.ID) + depGraph.AddDependency(current.ID, parent.ID) return nil }) } - return imgList, nil + + traversalMap, err := depGraph.GenerateTraversalMap() + if err != nil { + return nil, err + } + + utils.Debugf("Traversal map: %v", traversalMap) + result := [][]*registry.ImgData{} + for _, round := range traversalMap { + dataRound := []*registry.ImgData{} + for _, imgID := range round { + dataRound = append(dataRound, imgList[imgID]) + } + result = append(result, dataRound) + } + return result, nil +} + +func flatten(slc [][]*registry.ImgData) []*registry.ImgData { + result := []*registry.ImgData{} + for _, x := range slc { + result = append(result, x...) + } + return result } func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName, remoteName string, localRepo map[string]string, indexEp string, sf *utils.StreamFormatter) error { @@ -698,39 +726,43 @@ func (srv *Server) pushRepository(r *registry.Registry, out io.Writer, localName if err != nil { return err } + flattenedImgList := flatten(imgList) out.Write(sf.FormatStatus("", "Sending image list")) var repoData *registry.RepositoryData - repoData, err = r.PushImageJSONIndex(indexEp, remoteName, imgList, false, nil) + repoData, err = r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, false, nil) if err != nil { return err } for _, ep := range repoData.Endpoints { out.Write(sf.FormatStatus("", "Pushing repository %s (%d tags)", localName, len(localRepo))) - // For each image within the repo, push them - for _, elem := range imgList { - if _, exists := repoData.ImgList[elem.ID]; exists { - out.Write(sf.FormatStatus("", "Image %s already pushed, skipping", elem.ID)) - continue - } else if r.LookupRemoteImage(elem.ID, ep, repoData.Tokens) { - out.Write(sf.FormatStatus("", "Image %s already pushed, skipping", elem.ID)) - continue - } - if checksum, err := srv.pushImage(r, out, remoteName, elem.ID, ep, repoData.Tokens, sf); err != nil { - // FIXME: Continue on error? - return err - } else { - elem.Checksum = checksum - } - out.Write(sf.FormatStatus("", "Pushing tags for rev [%s] on {%s}", elem.ID, ep+"repositories/"+remoteName+"/tags/"+elem.Tag)) - if err := r.PushRegistryTag(remoteName, elem.ID, elem.Tag, ep, repoData.Tokens); err != nil { - return err + // This section can not be parallelized (each round depends on the previous one) + for _, round := range imgList { + // FIXME: This section can be parallelized + for _, elem := range round { + if _, exists := repoData.ImgList[elem.ID]; exists { + out.Write(sf.FormatStatus("", "Image %s already pushed, skipping", elem.ID)) + continue + } else if r.LookupRemoteImage(elem.ID, ep, repoData.Tokens) { + out.Write(sf.FormatStatus("", "Image %s already pushed, skipping", elem.ID)) + continue + } + if checksum, err := srv.pushImage(r, out, remoteName, elem.ID, ep, repoData.Tokens, sf); err != nil { + // FIXME: Continue on error? + return err + } else { + elem.Checksum = checksum + } + out.Write(sf.FormatStatus("", "Pushing tags for rev [%s] on {%s}", elem.ID, ep+"repositories/"+remoteName+"/tags/"+elem.Tag)) + if err := r.PushRegistryTag(remoteName, elem.ID, elem.Tag, ep, repoData.Tokens); err != nil { + return err + } } } } - if _, err := r.PushImageJSONIndex(indexEp, remoteName, imgList, true, repoData.Endpoints); err != nil { + if _, err := r.PushImageJSONIndex(indexEp, remoteName, flattenedImgList, true, repoData.Endpoints); err != nil { return err } diff --git a/utils/utils.go b/utils/utils.go index e8cf08aaba..b761da0fbd 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -889,3 +889,119 @@ func UserLookup(uid string) (*user.User, error) { } return nil, fmt.Errorf("User not found in /etc/passwd") } + +type DependencyGraph struct{ + nodes map[string]*DependencyNode +} + +type DependencyNode struct{ + id string + deps map[*DependencyNode]bool +} + +func NewDependencyGraph() DependencyGraph { + return DependencyGraph{ + nodes: map[string]*DependencyNode{}, + } +} + +func (graph *DependencyGraph) addNode(node *DependencyNode) string { + if graph.nodes[node.id] == nil { + graph.nodes[node.id] = node + } + return node.id +} + +func (graph *DependencyGraph) NewNode(id string) string { + if graph.nodes[id] != nil { + return id + } + nd := &DependencyNode{ + id: id, + deps: map[*DependencyNode]bool{}, + } + graph.addNode(nd) + return id +} + +func (graph *DependencyGraph) AddDependency(node, to string) error { + if graph.nodes[node] == nil { + return fmt.Errorf("Node %s does not belong to this graph", node) + } + + if graph.nodes[to] == nil { + return fmt.Errorf("Node %s does not belong to this graph", to) + } + + if node == to { + return fmt.Errorf("Dependency loops are forbidden!") + } + + graph.nodes[node].addDependency(graph.nodes[to]) + return nil +} + +func (node *DependencyNode) addDependency(to *DependencyNode) bool { + node.deps[to] = true + return node.deps[to] +} + +func (node *DependencyNode) Degree() int { + return len(node.deps) +} + +// The magic happens here :: +func (graph *DependencyGraph) GenerateTraversalMap() ([][]string, error) { + Debugf("Generating traversal map. Nodes: %d", len(graph.nodes)) + result := [][]string{} + processed := map[*DependencyNode]bool{} + // As long as we haven't processed all nodes... + for len(processed) < len(graph.nodes) { + // Use a temporary buffer for processed nodes, otherwise + // nodes that depend on each other could end up in the same round. + tmp_processed := []*DependencyNode{} + for _, node := range graph.nodes { + // If the node has more dependencies than what we have cleared, + // it won't be valid for this round. + if node.Degree() > len(processed) { + continue + } + // If it's already processed, get to the next one + if processed[node] { + continue + } + // It's not been processed yet and has 0 deps. Add it! + // (this is a shortcut for what we're doing below) + if node.Degree() == 0 { + tmp_processed = append(tmp_processed, node) + continue + } + // If at least one dep hasn't been processed yet, we can't + // add it. + ok := true + for dep, _ := range node.deps { + if !processed[dep] { + ok = false + break + } + } + // All deps have already been processed. Add it! + if ok { + tmp_processed = append(tmp_processed, node) + } + } + Debugf("Round %d: found %d available nodes", len(result), len(tmp_processed)) + // If no progress has been made this round, + // that means we have circular dependencies. + if len(tmp_processed) == 0 { + return nil, fmt.Errorf("Could not find a solution to this dependency graph") + } + round := []string{} + for _, nd := range tmp_processed { + round = append(round, nd.id) + processed[nd] = true + } + result = append(result, round) + } + return result, nil +} \ No newline at end of file diff --git a/utils/utils_test.go b/utils/utils_test.go index 3341650860..be796b2381 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -365,3 +365,60 @@ func TestParseRelease(t *testing.T) { assertParseRelease(t, "3.4.54.longterm-1", &KernelVersionInfo{Kernel: 3, Major: 4, Minor: 54, Flavor: "1"}, 0) assertParseRelease(t, "3.8.0-19-generic", &KernelVersionInfo{Kernel: 3, Major: 8, Minor: 0, Flavor: "19-generic"}, 0) } + + +func TestDependencyGraphCircular(t *testing.T) { + g1 := NewDependencyGraph() + a := g1.NewNode("a") + b := g1.NewNode("b") + g1.AddDependency(a, b) + g1.AddDependency(b, a) + res, err := g1.GenerateTraversalMap() + if res != nil { + t.Fatalf("Expected nil result") + } + if err == nil { + t.Fatalf("Expected error (circular graph can not be resolved)") + } +} + +func TestDependencyGraph(t *testing.T) { + g1 := NewDependencyGraph() + a := g1.NewNode("a") + b := g1.NewNode("b") + c := g1.NewNode("c") + d := g1.NewNode("d") + g1.AddDependency(b, a) + g1.AddDependency(c, a) + g1.AddDependency(d, c) + g1.AddDependency(d, b) + res, err := g1.GenerateTraversalMap() + + if err != nil { + t.Fatalf("%s", err) + } + + if res == nil { + t.Fatalf("Unexpected nil result") + } + + if len(res) != 3 { + t.Fatalf("Expected map of length 3, found %d instead", len(res)) + } + + if len(res[0]) != 1 || res[0][0] != "a" { + t.Fatalf("Expected [a], found %v instead", res[0]) + } + + if len(res[1]) != 2 { + t.Fatalf("Expected 2 nodes for step 2, found %d", len(res[1])) + } + + if (res[1][0] != "b" && res[1][1] != "b") || (res[1][0] != "c" && res[1][1] != "c") { + t.Fatalf("Expected [b, c], found %v instead", res[1]) + } + + if len(res[2]) != 1 || res[2][0] != "d" { + t.Fatalf("Expected [d], found %v instead", res[2]) + } +} \ No newline at end of file