diff --git a/api/client/service/update.go b/api/client/service/update.go index 4f432331d7..5cf3aa84cc 100644 --- a/api/client/service/update.go +++ b/api/client/service/update.go @@ -2,6 +2,7 @@ package service import ( "fmt" + "sort" "strings" "time" @@ -216,7 +217,9 @@ func updateService(flags *pflag.FlagSet, spec *swarm.ServiceSpec) error { if spec.EndpointSpec == nil { spec.EndpointSpec = &swarm.EndpointSpec{} } - updatePorts(flags, &spec.EndpointSpec.Ports) + if err := updatePorts(flags, &spec.EndpointSpec.Ports); err != nil { + return err + } } if err := updateLogDriver(flags, &spec.TaskTemplate); err != nil { @@ -369,23 +372,54 @@ func updateMounts(flags *pflag.FlagSet, mounts *[]swarm.Mount) { *mounts = newMounts } -func updatePorts(flags *pflag.FlagSet, portConfig *[]swarm.PortConfig) { +type byPortConfig []swarm.PortConfig + +func (r byPortConfig) Len() int { return len(r) } +func (r byPortConfig) Swap(i, j int) { r[i], r[j] = r[j], r[i] } +func (r byPortConfig) Less(i, j int) bool { + // We convert PortConfig into `port/protocol`, e.g., `80/tcp` + // In updatePorts we already filter out with map so there is duplicate entries + return portConfigToString(&r[i]) < portConfigToString(&r[j]) +} + +func portConfigToString(portConfig *swarm.PortConfig) string { + protocol := portConfig.Protocol + if protocol == "" { + protocol = "tcp" + } + return fmt.Sprintf("%v/%s", portConfig.PublishedPort, protocol) +} + +func updatePorts(flags *pflag.FlagSet, portConfig *[]swarm.PortConfig) error { + // The key of the map is `port/protocol`, e.g., `80/tcp` + portSet := map[string]swarm.PortConfig{} + // Check to see if there are any conflict in flags. if flags.Changed(flagPublishAdd) { values := flags.Lookup(flagPublishAdd).Value.(*opts.ListOpts).GetAll() ports, portBindings, _ := nat.ParsePortSpecs(values) for port := range ports { - *portConfig = append(*portConfig, convertPortToPortConfig(port, portBindings)...) + newConfigs := convertPortToPortConfig(port, portBindings) + for _, entry := range newConfigs { + if v, ok := portSet[portConfigToString(&entry)]; ok && v != entry { + return fmt.Errorf("conflicting port mapping between %v:%v/%s and %v:%v/%s", entry.PublishedPort, entry.TargetPort, entry.Protocol, v.PublishedPort, v.TargetPort, v.Protocol) + } + portSet[portConfigToString(&entry)] = entry + } } } - if !flags.Changed(flagPublishRemove) { - return + // Override previous PortConfig in service if there is any duplicate + for _, entry := range *portConfig { + if _, ok := portSet[portConfigToString(&entry)]; !ok { + portSet[portConfigToString(&entry)] = entry + } } + toRemove := flags.Lookup(flagPublishRemove).Value.(*opts.ListOpts).GetAll() newPorts := []swarm.PortConfig{} portLoop: - for _, port := range *portConfig { + for _, port := range portSet { for _, rawTargetPort := range toRemove { targetPort := nat.Port(rawTargetPort) if equalPort(targetPort, port) { @@ -394,7 +428,10 @@ portLoop: } newPorts = append(newPorts, port) } + // Sort the PortConfig to avoid unnecessary updates + sort.Sort(byPortConfig(newPorts)) *portConfig = newPorts + return nil } func equalPort(targetPort nat.Port, port swarm.PortConfig) bool { diff --git a/api/client/service/update_test.go b/api/client/service/update_test.go index ce0609d264..2c1cf1bcbb 100644 --- a/api/client/service/update_test.go +++ b/api/client/service/update_test.go @@ -141,8 +141,56 @@ func TestUpdatePorts(t *testing.T) { {TargetPort: 555}, } - updatePorts(flags, &portConfigs) + err := updatePorts(flags, &portConfigs) + assert.Equal(t, err, nil) assert.Equal(t, len(portConfigs), 2) - assert.Equal(t, portConfigs[0].TargetPort, uint32(555)) - assert.Equal(t, portConfigs[1].TargetPort, uint32(1000)) + // Do a sort to have the order (might have changed by map) + targetPorts := []int{int(portConfigs[0].TargetPort), int(portConfigs[1].TargetPort)} + sort.Ints(targetPorts) + assert.Equal(t, targetPorts[0], 555) + assert.Equal(t, targetPorts[1], 1000) +} + +func TestUpdatePortsDuplicateEntries(t *testing.T) { + // Test case for #25375 + flags := newUpdateCommand(nil).Flags() + flags.Set("publish-add", "80:80") + + portConfigs := []swarm.PortConfig{ + {TargetPort: 80, PublishedPort: 80}, + } + + err := updatePorts(flags, &portConfigs) + assert.Equal(t, err, nil) + assert.Equal(t, len(portConfigs), 1) + assert.Equal(t, portConfigs[0].TargetPort, uint32(80)) +} + +func TestUpdatePortsDuplicateKeys(t *testing.T) { + // Test case for #25375 + flags := newUpdateCommand(nil).Flags() + flags.Set("publish-add", "80:20") + + portConfigs := []swarm.PortConfig{ + {TargetPort: 80, PublishedPort: 80}, + } + + err := updatePorts(flags, &portConfigs) + assert.Equal(t, err, nil) + assert.Equal(t, len(portConfigs), 1) + assert.Equal(t, portConfigs[0].TargetPort, uint32(20)) +} + +func TestUpdatePortsConflictingFlags(t *testing.T) { + // Test case for #25375 + flags := newUpdateCommand(nil).Flags() + flags.Set("publish-add", "80:80") + flags.Set("publish-add", "80:20") + + portConfigs := []swarm.PortConfig{ + {TargetPort: 80, PublishedPort: 80}, + } + + err := updatePorts(flags, &portConfigs) + assert.Error(t, err, "conflicting port mapping") } diff --git a/integration-cli/docker_cli_swarm_test.go b/integration-cli/docker_cli_swarm_test.go index 8cf5b6e0e9..4acdf3fcb0 100644 --- a/integration-cli/docker_cli_swarm_test.go +++ b/integration-cli/docker_cli_swarm_test.go @@ -191,3 +191,29 @@ func (s *DockerSwarmSuite) TestSwarmNodeTaskListFilter(c *check.C) { c.Assert(out, checker.Not(checker.Contains), name+".2") c.Assert(out, checker.Not(checker.Contains), name+".3") } + +// Test case for #25375 +func (s *DockerSwarmSuite) TestSwarmPublishAdd(c *check.C) { + d := s.AddDaemon(c, true, true) + + name := "top" + out, err := d.Cmd("service", "create", "--name", name, "--label", "x=y", "busybox", "top") + c.Assert(err, checker.IsNil) + c.Assert(strings.TrimSpace(out), checker.Not(checker.Equals), "") + + out, err = d.Cmd("service", "update", "--publish-add", "80:80", name) + c.Assert(err, checker.IsNil) + + out, err = d.Cmd("service", "update", "--publish-add", "80:80", name) + c.Assert(err, checker.IsNil) + + out, err = d.Cmd("service", "update", "--publish-add", "80:80", "--publish-add", "80:20", name) + c.Assert(err, checker.NotNil) + + out, err = d.Cmd("service", "update", "--publish-add", "80:20", name) + c.Assert(err, checker.IsNil) + + out, err = d.Cmd("service", "inspect", "--format", "{{ .Spec.EndpointSpec.Ports }}", name) + c.Assert(err, checker.IsNil) + c.Assert(strings.TrimSpace(out), checker.Equals, "[{ tcp 20 80}]") +}