mirror of
				https://github.com/moby/moby.git
				synced 2022-11-09 12:21:53 -05:00 
			
		
		
		
	Add pubsub package to handle robust publisher
Signed-off-by: Michael Crosby <crosbymichael@gmail.com>
This commit is contained in:
		
							parent
							
								
									2d4fc1de05
								
							
						
					
					
						commit
						2f46b7601a
					
				
					 6 changed files with 248 additions and 156 deletions
				
			
		| 
						 | 
				
			
			@ -1,11 +1,8 @@
 | 
			
		|||
// This package is used for API stability in the types and response to the
 | 
			
		||||
// consumers of the API stats endpoint.
 | 
			
		||||
package stats
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/docker/libcontainer"
 | 
			
		||||
	"github.com/docker/libcontainer/cgroups"
 | 
			
		||||
)
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
type ThrottlingData struct {
 | 
			
		||||
	// Number of periods with throttling active
 | 
			
		||||
| 
						 | 
				
			
			@ -88,69 +85,3 @@ type Stats struct {
 | 
			
		|||
	MemoryStats MemoryStats `json:"memory_stats,omitempty"`
 | 
			
		||||
	BlkioStats  BlkioStats  `json:"blkio_stats,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ToStats converts the libcontainer.ContainerStats to the api specific
 | 
			
		||||
// structs.  This is done to preserve API compatibility and versioning.
 | 
			
		||||
func ToStats(ls *libcontainer.ContainerStats) *Stats {
 | 
			
		||||
	s := &Stats{}
 | 
			
		||||
	if ls.NetworkStats != nil {
 | 
			
		||||
		s.Network = Network{
 | 
			
		||||
			RxBytes:   ls.NetworkStats.RxBytes,
 | 
			
		||||
			RxPackets: ls.NetworkStats.RxPackets,
 | 
			
		||||
			RxErrors:  ls.NetworkStats.RxErrors,
 | 
			
		||||
			RxDropped: ls.NetworkStats.RxDropped,
 | 
			
		||||
			TxBytes:   ls.NetworkStats.TxBytes,
 | 
			
		||||
			TxPackets: ls.NetworkStats.TxPackets,
 | 
			
		||||
			TxErrors:  ls.NetworkStats.TxErrors,
 | 
			
		||||
			TxDropped: ls.NetworkStats.TxDropped,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	cs := ls.CgroupStats
 | 
			
		||||
	if cs != nil {
 | 
			
		||||
		s.BlkioStats = BlkioStats{
 | 
			
		||||
			IoServiceBytesRecursive: copyBlkioEntry(cs.BlkioStats.IoServiceBytesRecursive),
 | 
			
		||||
			IoServicedRecursive:     copyBlkioEntry(cs.BlkioStats.IoServicedRecursive),
 | 
			
		||||
			IoQueuedRecursive:       copyBlkioEntry(cs.BlkioStats.IoQueuedRecursive),
 | 
			
		||||
			IoServiceTimeRecursive:  copyBlkioEntry(cs.BlkioStats.IoServiceTimeRecursive),
 | 
			
		||||
			IoWaitTimeRecursive:     copyBlkioEntry(cs.BlkioStats.IoWaitTimeRecursive),
 | 
			
		||||
			IoMergedRecursive:       copyBlkioEntry(cs.BlkioStats.IoMergedRecursive),
 | 
			
		||||
			IoTimeRecursive:         copyBlkioEntry(cs.BlkioStats.IoTimeRecursive),
 | 
			
		||||
			SectorsRecursive:        copyBlkioEntry(cs.BlkioStats.SectorsRecursive),
 | 
			
		||||
		}
 | 
			
		||||
		cpu := cs.CpuStats
 | 
			
		||||
		s.CpuStats = CpuStats{
 | 
			
		||||
			CpuUsage: CpuUsage{
 | 
			
		||||
				TotalUsage:        cpu.CpuUsage.TotalUsage,
 | 
			
		||||
				PercpuUsage:       cpu.CpuUsage.PercpuUsage,
 | 
			
		||||
				UsageInKernelmode: cpu.CpuUsage.UsageInKernelmode,
 | 
			
		||||
				UsageInUsermode:   cpu.CpuUsage.UsageInUsermode,
 | 
			
		||||
			},
 | 
			
		||||
			ThrottlingData: ThrottlingData{
 | 
			
		||||
				Periods:          cpu.ThrottlingData.Periods,
 | 
			
		||||
				ThrottledPeriods: cpu.ThrottlingData.ThrottledPeriods,
 | 
			
		||||
				ThrottledTime:    cpu.ThrottlingData.ThrottledTime,
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		mem := cs.MemoryStats
 | 
			
		||||
		s.MemoryStats = MemoryStats{
 | 
			
		||||
			Usage:    mem.Usage,
 | 
			
		||||
			MaxUsage: mem.MaxUsage,
 | 
			
		||||
			Stats:    mem.Stats,
 | 
			
		||||
			Failcnt:  mem.Failcnt,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyBlkioEntry(entries []cgroups.BlkioStatEntry) []BlkioStatEntry {
 | 
			
		||||
	out := make([]BlkioStatEntry, len(entries))
 | 
			
		||||
	for i, re := range entries {
 | 
			
		||||
		out[i] = BlkioStatEntry{
 | 
			
		||||
			Major: re.Major,
 | 
			
		||||
			Minor: re.Minor,
 | 
			
		||||
			Op:    re.Op,
 | 
			
		||||
			Value: re.Value,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return out
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1099,7 +1099,7 @@ func (daemon *Daemon) Stats(c *Container) (*execdriver.ResourceStats, error) {
 | 
			
		|||
	return daemon.execDriver.Stats(c.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (daemon *Daemon) SubscribeToContainerStats(name string) (chan *execdriver.ResourceStats, error) {
 | 
			
		||||
func (daemon *Daemon) SubscribeToContainerStats(name string) (chan interface{}, error) {
 | 
			
		||||
	c := daemon.Get(name)
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return nil, fmt.Errorf("no such container")
 | 
			
		||||
| 
						 | 
				
			
			@ -1108,7 +1108,7 @@ func (daemon *Daemon) SubscribeToContainerStats(name string) (chan *execdriver.R
 | 
			
		|||
	return ch, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (daemon *Daemon) UnsubscribeToContainerStats(name string, ch chan *execdriver.ResourceStats) error {
 | 
			
		||||
func (daemon *Daemon) UnsubscribeToContainerStats(name string, ch chan interface{}) error {
 | 
			
		||||
	c := daemon.Get(name)
 | 
			
		||||
	if c == nil {
 | 
			
		||||
		return fmt.Errorf("no such container")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,25 +4,95 @@ import (
 | 
			
		|||
	"encoding/json"
 | 
			
		||||
 | 
			
		||||
	"github.com/docker/docker/api/stats"
 | 
			
		||||
	"github.com/docker/docker/daemon/execdriver"
 | 
			
		||||
	"github.com/docker/docker/engine"
 | 
			
		||||
	"github.com/docker/libcontainer"
 | 
			
		||||
	"github.com/docker/libcontainer/cgroups"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (daemon *Daemon) ContainerStats(job *engine.Job) engine.Status {
 | 
			
		||||
	s, err := daemon.SubscribeToContainerStats(job.Args[0])
 | 
			
		||||
	updates, err := daemon.SubscribeToContainerStats(job.Args[0])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return job.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
	enc := json.NewEncoder(job.Stdout)
 | 
			
		||||
	for update := range s {
 | 
			
		||||
		ss := stats.ToStats(update.ContainerStats)
 | 
			
		||||
	for v := range updates {
 | 
			
		||||
		update := v.(*execdriver.ResourceStats)
 | 
			
		||||
		ss := convertToAPITypes(update.ContainerStats)
 | 
			
		||||
		ss.MemoryStats.Limit = uint64(update.MemoryLimit)
 | 
			
		||||
		ss.Read = update.Read
 | 
			
		||||
		ss.CpuStats.SystemUsage = update.SystemUsage
 | 
			
		||||
		if err := enc.Encode(ss); err != nil {
 | 
			
		||||
			// TODO: handle the specific broken pipe
 | 
			
		||||
			daemon.UnsubscribeToContainerStats(job.Args[0], s)
 | 
			
		||||
			daemon.UnsubscribeToContainerStats(job.Args[0], updates)
 | 
			
		||||
			return job.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return engine.StatusOK
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// convertToAPITypes converts the libcontainer.ContainerStats to the api specific
 | 
			
		||||
// structs.  This is done to preserve API compatibility and versioning.
 | 
			
		||||
func convertToAPITypes(ls *libcontainer.ContainerStats) *stats.Stats {
 | 
			
		||||
	s := &stats.Stats{}
 | 
			
		||||
	if ls.NetworkStats != nil {
 | 
			
		||||
		s.Network = stats.Network{
 | 
			
		||||
			RxBytes:   ls.NetworkStats.RxBytes,
 | 
			
		||||
			RxPackets: ls.NetworkStats.RxPackets,
 | 
			
		||||
			RxErrors:  ls.NetworkStats.RxErrors,
 | 
			
		||||
			RxDropped: ls.NetworkStats.RxDropped,
 | 
			
		||||
			TxBytes:   ls.NetworkStats.TxBytes,
 | 
			
		||||
			TxPackets: ls.NetworkStats.TxPackets,
 | 
			
		||||
			TxErrors:  ls.NetworkStats.TxErrors,
 | 
			
		||||
			TxDropped: ls.NetworkStats.TxDropped,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	cs := ls.CgroupStats
 | 
			
		||||
	if cs != nil {
 | 
			
		||||
		s.BlkioStats = stats.BlkioStats{
 | 
			
		||||
			IoServiceBytesRecursive: copyBlkioEntry(cs.BlkioStats.IoServiceBytesRecursive),
 | 
			
		||||
			IoServicedRecursive:     copyBlkioEntry(cs.BlkioStats.IoServicedRecursive),
 | 
			
		||||
			IoQueuedRecursive:       copyBlkioEntry(cs.BlkioStats.IoQueuedRecursive),
 | 
			
		||||
			IoServiceTimeRecursive:  copyBlkioEntry(cs.BlkioStats.IoServiceTimeRecursive),
 | 
			
		||||
			IoWaitTimeRecursive:     copyBlkioEntry(cs.BlkioStats.IoWaitTimeRecursive),
 | 
			
		||||
			IoMergedRecursive:       copyBlkioEntry(cs.BlkioStats.IoMergedRecursive),
 | 
			
		||||
			IoTimeRecursive:         copyBlkioEntry(cs.BlkioStats.IoTimeRecursive),
 | 
			
		||||
			SectorsRecursive:        copyBlkioEntry(cs.BlkioStats.SectorsRecursive),
 | 
			
		||||
		}
 | 
			
		||||
		cpu := cs.CpuStats
 | 
			
		||||
		s.CpuStats = stats.CpuStats{
 | 
			
		||||
			CpuUsage: stats.CpuUsage{
 | 
			
		||||
				TotalUsage:        cpu.CpuUsage.TotalUsage,
 | 
			
		||||
				PercpuUsage:       cpu.CpuUsage.PercpuUsage,
 | 
			
		||||
				UsageInKernelmode: cpu.CpuUsage.UsageInKernelmode,
 | 
			
		||||
				UsageInUsermode:   cpu.CpuUsage.UsageInUsermode,
 | 
			
		||||
			},
 | 
			
		||||
			ThrottlingData: stats.ThrottlingData{
 | 
			
		||||
				Periods:          cpu.ThrottlingData.Periods,
 | 
			
		||||
				ThrottledPeriods: cpu.ThrottlingData.ThrottledPeriods,
 | 
			
		||||
				ThrottledTime:    cpu.ThrottlingData.ThrottledTime,
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		mem := cs.MemoryStats
 | 
			
		||||
		s.MemoryStats = stats.MemoryStats{
 | 
			
		||||
			Usage:    mem.Usage,
 | 
			
		||||
			MaxUsage: mem.MaxUsage,
 | 
			
		||||
			Stats:    mem.Stats,
 | 
			
		||||
			Failcnt:  mem.Failcnt,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func copyBlkioEntry(entries []cgroups.BlkioStatEntry) []stats.BlkioStatEntry {
 | 
			
		||||
	out := make([]stats.BlkioStatEntry, len(entries))
 | 
			
		||||
	for i, re := range entries {
 | 
			
		||||
		out[i] = stats.BlkioStatEntry{
 | 
			
		||||
			Major: re.Major,
 | 
			
		||||
			Minor: re.Minor,
 | 
			
		||||
			Op:    re.Op,
 | 
			
		||||
			Value: re.Value,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return out
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@ import (
 | 
			
		|||
 | 
			
		||||
	log "github.com/Sirupsen/logrus"
 | 
			
		||||
	"github.com/docker/docker/daemon/execdriver"
 | 
			
		||||
	"github.com/docker/docker/pkg/pubsub"
 | 
			
		||||
	"github.com/docker/libcontainer/system"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -21,114 +22,75 @@ import (
 | 
			
		|||
func newStatsCollector(interval time.Duration) *statsCollector {
 | 
			
		||||
	s := &statsCollector{
 | 
			
		||||
		interval:   interval,
 | 
			
		||||
		containers: make(map[string]*statsData),
 | 
			
		||||
		publishers: make(map[*Container]*pubsub.Publisher),
 | 
			
		||||
		clockTicks: uint64(system.GetClockTicks()),
 | 
			
		||||
	}
 | 
			
		||||
	s.start()
 | 
			
		||||
	go s.run()
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type statsData struct {
 | 
			
		||||
	c         *Container
 | 
			
		||||
	lastStats *execdriver.ResourceStats
 | 
			
		||||
	subs      []chan *execdriver.ResourceStats
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// statsCollector manages and provides container resource stats
 | 
			
		||||
type statsCollector struct {
 | 
			
		||||
	m          sync.Mutex
 | 
			
		||||
	interval   time.Duration
 | 
			
		||||
	clockTicks uint64
 | 
			
		||||
	containers map[string]*statsData
 | 
			
		||||
	publishers map[*Container]*pubsub.Publisher
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// collect registers the container with the collector and adds it to
 | 
			
		||||
// the event loop for collection on the specified interval returning
 | 
			
		||||
// a channel for the subscriber to receive on.
 | 
			
		||||
func (s *statsCollector) collect(c *Container) chan *execdriver.ResourceStats {
 | 
			
		||||
func (s *statsCollector) collect(c *Container) chan interface{} {
 | 
			
		||||
	s.m.Lock()
 | 
			
		||||
	defer s.m.Unlock()
 | 
			
		||||
	ch := make(chan *execdriver.ResourceStats, 1024)
 | 
			
		||||
	if _, exists := s.containers[c.ID]; exists {
 | 
			
		||||
		s.containers[c.ID].subs = append(s.containers[c.ID].subs, ch)
 | 
			
		||||
		return ch
 | 
			
		||||
	publisher, exists := s.publishers[c]
 | 
			
		||||
	if !exists {
 | 
			
		||||
		publisher = pubsub.NewPublisher(100*time.Millisecond, 1024)
 | 
			
		||||
		s.publishers[c] = publisher
 | 
			
		||||
	}
 | 
			
		||||
	s.containers[c.ID] = &statsData{
 | 
			
		||||
		c: c,
 | 
			
		||||
		subs: []chan *execdriver.ResourceStats{
 | 
			
		||||
			ch,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	return ch
 | 
			
		||||
	return publisher.Subscribe()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// stopCollection closes the channels for all subscribers and removes
 | 
			
		||||
// the container from metrics collection.
 | 
			
		||||
func (s *statsCollector) stopCollection(c *Container) {
 | 
			
		||||
	s.m.Lock()
 | 
			
		||||
	defer s.m.Unlock()
 | 
			
		||||
	d := s.containers[c.ID]
 | 
			
		||||
	if d == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	for _, sub := range d.subs {
 | 
			
		||||
		close(sub)
 | 
			
		||||
	}
 | 
			
		||||
	delete(s.containers, c.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// unsubscribe removes a specific subscriber from receiving updates for a
 | 
			
		||||
// container's stats.
 | 
			
		||||
func (s *statsCollector) unsubscribe(c *Container, ch chan *execdriver.ResourceStats) {
 | 
			
		||||
	s.m.Lock()
 | 
			
		||||
	cd := s.containers[c.ID]
 | 
			
		||||
	for i, sub := range cd.subs {
 | 
			
		||||
		if ch == sub {
 | 
			
		||||
			cd.subs = append(cd.subs[:i], cd.subs[i+1:]...)
 | 
			
		||||
			close(ch)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// if there are no more subscribers then remove the entire container
 | 
			
		||||
	// from collection.
 | 
			
		||||
	if len(cd.subs) == 0 {
 | 
			
		||||
		delete(s.containers, c.ID)
 | 
			
		||||
	if publisher, exists := s.publishers[c]; exists {
 | 
			
		||||
		publisher.Close()
 | 
			
		||||
		delete(s.publishers, c)
 | 
			
		||||
	}
 | 
			
		||||
	s.m.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *statsCollector) start() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for _ = range time.Tick(s.interval) {
 | 
			
		||||
			s.m.Lock()
 | 
			
		||||
			for id, d := range s.containers {
 | 
			
		||||
				systemUsage, err := s.getSystemCpuUsage()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Errorf("collecting system cpu usage for %s: %v", id, err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				stats, err := d.c.Stats()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					if err == execdriver.ErrNotRunning {
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					// if the error is not because the container is currently running then
 | 
			
		||||
					// evict the container from the collector and close the channel for
 | 
			
		||||
					// any subscribers currently waiting on changes.
 | 
			
		||||
					log.Errorf("collecting stats for %s: %v", id, err)
 | 
			
		||||
					for _, sub := range s.containers[id].subs {
 | 
			
		||||
						close(sub)
 | 
			
		||||
					}
 | 
			
		||||
					delete(s.containers, id)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				stats.SystemUsage = systemUsage
 | 
			
		||||
				for _, sub := range s.containers[id].subs {
 | 
			
		||||
					sub <- stats
 | 
			
		||||
				}
 | 
			
		||||
// unsubscribe removes a specific subscriber from receiving updates for a container's stats.
 | 
			
		||||
func (s *statsCollector) unsubscribe(c *Container, ch chan interface{}) {
 | 
			
		||||
	s.m.Lock()
 | 
			
		||||
	publisher := s.publishers[c]
 | 
			
		||||
	if publisher != nil {
 | 
			
		||||
		publisher.Evict(ch)
 | 
			
		||||
	}
 | 
			
		||||
	s.m.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *statsCollector) run() {
 | 
			
		||||
	for _ = range time.Tick(s.interval) {
 | 
			
		||||
		for container, publisher := range s.publishers {
 | 
			
		||||
			systemUsage, err := s.getSystemCpuUsage()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Errorf("collecting system cpu usage for %s: %v", container.ID, err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			s.m.Unlock()
 | 
			
		||||
			stats, err := container.Stats()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				if err != execdriver.ErrNotRunning {
 | 
			
		||||
					log.Errorf("collecting stats for %s: %v", container.ID, err)
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			stats.SystemUsage = systemUsage
 | 
			
		||||
			publisher.Publish(stats)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const nanoSeconds = 1e9
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										66
									
								
								pkg/pubsub/publisher.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								pkg/pubsub/publisher.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,66 @@
 | 
			
		|||
package pubsub
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewPublisher creates a new pub/sub publisher to broadcast messages.
 | 
			
		||||
// The duration is used as the send timeout as to not block the publisher publishing
 | 
			
		||||
// messages to other clients if one client is slow or unresponsive.
 | 
			
		||||
// The buffer is used when creating new channels for subscribers.
 | 
			
		||||
func NewPublisher(publishTimeout time.Duration, buffer int) *Publisher {
 | 
			
		||||
	return &Publisher{
 | 
			
		||||
		buffer:      buffer,
 | 
			
		||||
		timeout:     publishTimeout,
 | 
			
		||||
		subscribers: make(map[subscriber]struct{}),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type subscriber chan interface{}
 | 
			
		||||
 | 
			
		||||
type Publisher struct {
 | 
			
		||||
	m           sync.RWMutex
 | 
			
		||||
	buffer      int
 | 
			
		||||
	timeout     time.Duration
 | 
			
		||||
	subscribers map[subscriber]struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Subscribe adds a new subscriber to the publisher returning the channel.
 | 
			
		||||
func (p *Publisher) Subscribe() chan interface{} {
 | 
			
		||||
	ch := make(chan interface{}, p.buffer)
 | 
			
		||||
	p.m.Lock()
 | 
			
		||||
	p.subscribers[ch] = struct{}{}
 | 
			
		||||
	p.m.Unlock()
 | 
			
		||||
	return ch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Evict removes the specified subscriber from receiving any more messages.
 | 
			
		||||
func (p *Publisher) Evict(sub chan interface{}) {
 | 
			
		||||
	p.m.Lock()
 | 
			
		||||
	delete(p.subscribers, sub)
 | 
			
		||||
	close(sub)
 | 
			
		||||
	p.m.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish sends the data in v to all subscribers currently registered with the publisher.
 | 
			
		||||
func (p *Publisher) Publish(v interface{}) {
 | 
			
		||||
	p.m.RLock()
 | 
			
		||||
	for sub := range p.subscribers {
 | 
			
		||||
		// send under a select as to not block if the receiver is unavailable
 | 
			
		||||
		select {
 | 
			
		||||
		case sub <- v:
 | 
			
		||||
		case <-time.After(p.timeout):
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	p.m.RUnlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close closes the channels to all subscribers registered with the publisher.
 | 
			
		||||
func (p *Publisher) Close() {
 | 
			
		||||
	p.m.Lock()
 | 
			
		||||
	for sub := range p.subscribers {
 | 
			
		||||
		close(sub)
 | 
			
		||||
	}
 | 
			
		||||
	p.m.Unlock()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										63
									
								
								pkg/pubsub/publisher_test.go
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								pkg/pubsub/publisher_test.go
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,63 @@
 | 
			
		|||
package pubsub
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestSendToOneSub(t *testing.T) {
 | 
			
		||||
	p := NewPublisher(100*time.Millisecond, 10)
 | 
			
		||||
	c := p.Subscribe()
 | 
			
		||||
 | 
			
		||||
	p.Publish("hi")
 | 
			
		||||
 | 
			
		||||
	msg := <-c
 | 
			
		||||
	if msg.(string) != "hi" {
 | 
			
		||||
		t.Fatalf("expected message hi but received %v", msg)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSendToMultipleSubs(t *testing.T) {
 | 
			
		||||
	p := NewPublisher(100*time.Millisecond, 10)
 | 
			
		||||
	subs := []chan interface{}{}
 | 
			
		||||
	subs = append(subs, p.Subscribe(), p.Subscribe(), p.Subscribe())
 | 
			
		||||
 | 
			
		||||
	p.Publish("hi")
 | 
			
		||||
 | 
			
		||||
	for _, c := range subs {
 | 
			
		||||
		msg := <-c
 | 
			
		||||
		if msg.(string) != "hi" {
 | 
			
		||||
			t.Fatalf("expected message hi but received %v", msg)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestEvictOneSub(t *testing.T) {
 | 
			
		||||
	p := NewPublisher(100*time.Millisecond, 10)
 | 
			
		||||
	s1 := p.Subscribe()
 | 
			
		||||
	s2 := p.Subscribe()
 | 
			
		||||
 | 
			
		||||
	p.Evict(s1)
 | 
			
		||||
	p.Publish("hi")
 | 
			
		||||
	if _, ok := <-s1; ok {
 | 
			
		||||
		t.Fatal("expected s1 to not receive the published message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msg := <-s2
 | 
			
		||||
	if msg.(string) != "hi" {
 | 
			
		||||
		t.Fatalf("expected message hi but received %v", msg)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestClosePublisher(t *testing.T) {
 | 
			
		||||
	p := NewPublisher(100*time.Millisecond, 10)
 | 
			
		||||
	subs := []chan interface{}{}
 | 
			
		||||
	subs = append(subs, p.Subscribe(), p.Subscribe(), p.Subscribe())
 | 
			
		||||
	p.Close()
 | 
			
		||||
 | 
			
		||||
	for _, c := range subs {
 | 
			
		||||
		if _, ok := <-c; ok {
 | 
			
		||||
			t.Fatal("expected all subscriber channels to be closed")
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue