moby--moby/vendor/src/github.com/docker/swarmkit/manager/dispatcher/dispatcher.go

package dispatcher
import (
"errors"
"fmt"
"reflect"
"sort"
"sync"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/transport"
"github.com/Sirupsen/logrus"
"github.com/docker/swarmkit/api"
"github.com/docker/swarmkit/ca"
"github.com/docker/swarmkit/log"
"github.com/docker/swarmkit/manager/state"
"github.com/docker/swarmkit/manager/state/store"
"github.com/docker/swarmkit/manager/state/watch"
"github.com/docker/swarmkit/protobuf/ptypes"
"golang.org/x/net/context"
)
const (
// DefaultHeartBeatPeriod is used for setting the default value in the cluster
// config, and as a fallback in case the cluster config is missing.
DefaultHeartBeatPeriod = 5 * time.Second
defaultHeartBeatEpsilon = 500 * time.Millisecond
defaultGracePeriodMultiplier = 3
defaultRateLimitPeriod = 16 * time.Second
// maxBatchItems is the threshold of queued writes that should
// trigger an actual transaction to commit them to the shared store.
maxBatchItems = 10000
// maxBatchInterval needs to strike a balance between keeping
// latency low, and realizing opportunities to combine many writes
// into a single transaction. A fraction of a second feels about
// right.
maxBatchInterval = 100 * time.Millisecond
)
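// Sketch of the batching pattern that Run and processTaskUpdates implement
// below: queued task updates are flushed either when the queue reaches
// maxBatchItems or when maxBatchInterval elapses, whichever comes first.
// `flush` here stands in for d.processTaskUpdates and is only illustrative.
//
//	select {
//	case <-processTaskUpdatesTrigger: // queue reached maxBatchItems
//		flush()
//		batchTimer.Reset(maxBatchInterval)
//	case <-batchTimer.C: // maxBatchInterval elapsed
//		flush()
//		batchTimer.Reset(maxBatchInterval)
//	}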
var (
// ErrNodeAlreadyRegistered is returned if a node with the same ID was already
// registered with this dispatcher.
ErrNodeAlreadyRegistered = errors.New("node already registered")
// ErrNodeNotRegistered is returned if a node with such an ID wasn't registered
// with this dispatcher.
ErrNodeNotRegistered = errors.New("node not registered")
// ErrSessionInvalid is returned when the session in use is no longer valid.
// The node should re-register and start a new session.
ErrSessionInvalid = errors.New("session invalid")
// ErrNodeNotFound is returned when the Node doesn't exist in raft.
ErrNodeNotFound = errors.New("node not found")
)
// Config is the configuration for the Dispatcher. For defaults, use
// DefaultConfig.
type Config struct {
// Addr configures the address the dispatcher reports to agents.
Addr string
HeartbeatPeriod time.Duration
HeartbeatEpsilon time.Duration
// RateLimitPeriod specifies how often a node with the same ID can try to
// register a new session.
RateLimitPeriod time.Duration
GracePeriodMultiplier int
}
// DefaultConfig returns default config for Dispatcher.
func DefaultConfig() *Config {
return &Config{
HeartbeatPeriod: DefaultHeartBeatPeriod,
HeartbeatEpsilon: defaultHeartBeatEpsilon,
RateLimitPeriod: defaultRateLimitPeriod,
GracePeriodMultiplier: defaultGracePeriodMultiplier,
}
}
// Cluster is an interface which represents a raft cluster. manager/state/raft.Node
// implements it. This interface is needed only for easier unit-testing.
type Cluster interface {
GetMemberlist() map[uint64]*api.RaftMember
MemoryStore() *store.MemoryStore
}
// Dispatcher is responsible for dispatching tasks and tracking agent health.
type Dispatcher struct {
mu sync.Mutex
addr string
nodes *nodeStore
store *store.MemoryStore
mgrQueue *watch.Queue
lastSeenManagers []*api.WeightedPeer
networkBootstrapKeys []*api.EncryptionKey
keyMgrQueue *watch.Queue
config *Config
cluster Cluster
ctx context.Context
cancel context.CancelFunc
taskUpdates map[string]*api.TaskStatus // indexed by task ID
taskUpdatesLock sync.Mutex
processTaskUpdatesTrigger chan struct{}
}
// weightedPeerByNodeID is a sort wrapper for []*api.WeightedPeer
type weightedPeerByNodeID []*api.WeightedPeer
func (b weightedPeerByNodeID) Less(i, j int) bool { return b[i].Peer.NodeID < b[j].Peer.NodeID }
func (b weightedPeerByNodeID) Len() int { return len(b) }
func (b weightedPeerByNodeID) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
// New returns a Dispatcher with a cluster interface (usually raft.Node).
// NOTE: each handler which does something with raft must add to Dispatcher.wg
func New(cluster Cluster, c *Config) *Dispatcher {
return &Dispatcher{
addr: c.Addr,
nodes: newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod),
store: cluster.MemoryStore(),
cluster: cluster,
mgrQueue: watch.NewQueue(16),
keyMgrQueue: watch.NewQueue(16),
taskUpdates: make(map[string]*api.TaskStatus),
processTaskUpdatesTrigger: make(chan struct{}, 1),
config: c,
}
}
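// Illustrative usage sketch (not part of the original file): wiring the
// Dispatcher to a raft-backed Cluster and running it until shutdown. The
// `raftNode` and `ctx` values are assumed to be provided by the caller;
// raftNode must implement the Cluster interface above.
//
//	d := New(raftNode, DefaultConfig())
//	go func() {
//		if err := d.Run(ctx); err != nil {
//			log.G(ctx).WithError(err).Error("dispatcher exited with an error")
//		}
//	}()
//	// Later, on leadership loss or shutdown:
//	if err := d.Stop(); err != nil {
//		log.G(ctx).WithError(err).Error("failed to stop dispatcher")
//	}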
func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
members := cluster.GetMemberlist()
var mgrs []*api.WeightedPeer
for _, m := range members {
mgrs = append(mgrs, &api.WeightedPeer{
Peer: &api.Peer{
NodeID: m.NodeID,
Addr: m.Addr,
},
Weight: 1,
})
}
return mgrs
}
// Run runs dispatcher tasks which should be run on the leader dispatcher.
// The Dispatcher can be stopped by cancelling ctx or by calling Stop().
func (d *Dispatcher) Run(ctx context.Context) error {
d.mu.Lock()
if d.isRunning() {
d.mu.Unlock()
return fmt.Errorf("dispatcher is already running")
}
logger := log.G(ctx).WithField("module", "dispatcher")
ctx = log.WithLogger(ctx, logger)
if err := d.markNodesUnknown(ctx); err != nil {
logger.Errorf(`failed to move all nodes to "unknown" state: %v`, err)
}
configWatcher, cancel, err := store.ViewAndWatch(
d.store,
func(readTx store.ReadTx) error {
clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
if err != nil {
return err
}
if err == nil && len(clusters) == 1 {
heartbeatPeriod, err := ptypes.Duration(clusters[0].Spec.Dispatcher.HeartbeatPeriod)
if err == nil && heartbeatPeriod > 0 {
d.config.HeartbeatPeriod = heartbeatPeriod
}
if clusters[0].NetworkBootstrapKeys != nil {
d.networkBootstrapKeys = clusters[0].NetworkBootstrapKeys
}
}
return nil
},
state.EventUpdateCluster{},
)
if err != nil {
d.mu.Unlock()
return err
}
defer cancel()
d.ctx, d.cancel = context.WithCancel(ctx)
d.mu.Unlock()
publishManagers := func() {
mgrs := getWeightedPeers(d.cluster)
sort.Sort(weightedPeerByNodeID(mgrs))
d.mu.Lock()
if reflect.DeepEqual(mgrs, d.lastSeenManagers) {
d.mu.Unlock()
return
}
d.lastSeenManagers = mgrs
d.mu.Unlock()
d.mgrQueue.Publish(mgrs)
}
publishManagers()
publishTicker := time.NewTicker(1 * time.Second)
defer publishTicker.Stop()
batchTimer := time.NewTimer(maxBatchInterval)
defer batchTimer.Stop()
for {
select {
case <-publishTicker.C:
publishManagers()
case <-d.processTaskUpdatesTrigger:
d.processTaskUpdates()
batchTimer.Reset(maxBatchInterval)
case <-batchTimer.C:
d.processTaskUpdates()
batchTimer.Reset(maxBatchInterval)
case v := <-configWatcher:
cluster := v.(state.EventUpdateCluster)
d.mu.Lock()
if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
// ignore error, since Spec has passed validation before
heartbeatPeriod, _ := ptypes.Duration(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
if heartbeatPeriod != d.config.HeartbeatPeriod {
// only call d.nodes.updatePeriod when heartbeatPeriod changes
d.config.HeartbeatPeriod = heartbeatPeriod
d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
}
}
d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
d.mu.Unlock()
d.keyMgrQueue.Publish(struct{}{})
case <-d.ctx.Done():
return nil
}
}
}
// Stop stops dispatcher and closes all grpc streams.
func (d *Dispatcher) Stop() error {
d.mu.Lock()
if !d.isRunning() {
d.mu.Unlock()
return fmt.Errorf("dispatcher is already stopped")
}
d.cancel()
d.mu.Unlock()
d.nodes.Clean()
return nil
}
func (d *Dispatcher) isRunningLocked() error {
d.mu.Lock()
if !d.isRunning() {
d.mu.Unlock()
return grpc.Errorf(codes.Aborted, "dispatcher is stopped")
}
d.mu.Unlock()
return nil
}
func (d *Dispatcher) markNodesUnknown(ctx context.Context) error {
log := log.G(ctx).WithField("method", "(*Dispatcher).markNodesUnknown")
var nodes []*api.Node
var err error
d.store.View(func(tx store.ReadTx) {
nodes, err = store.FindNodes(tx, store.All)
})
if err != nil {
return fmt.Errorf("failed to get list of nodes: %v", err)
}
_, err = d.store.Batch(func(batch *store.Batch) error {
for _, n := range nodes {
err := batch.Update(func(tx store.Tx) error {
// check if node is still here
node := store.GetNode(tx, n.ID)
if node == nil {
return nil
}
// do not try to resurrect down nodes
if node.Status.State == api.NodeStatus_DOWN {
return nil
}
node.Status = api.NodeStatus{
State: api.NodeStatus_UNKNOWN,
Message: `Node moved to "unknown" state due to leadership change in cluster`,
}
nodeID := node.ID
expireFunc := func() {
log := log.WithField("node", nodeID)
nodeStatus := api.NodeStatus{State: api.NodeStatus_DOWN, Message: `heartbeat failure for node in "unknown" state`}
log.Debugf("heartbeat expiration for unknown node")
if err := d.nodeRemove(nodeID, nodeStatus); err != nil {
log.WithError(err).Errorf(`failed deregistering node after heartbeat expiration for node in "unknown" state`)
}
}
if err := d.nodes.AddUnknown(node, expireFunc); err != nil {
return fmt.Errorf(`adding node in "unknown" state to node store failed: %v`, err)
}
if err := store.UpdateNode(tx, node); err != nil {
return fmt.Errorf("update failed %v", err)
}
return nil
})
if err != nil {
log.WithField("node", n.ID).WithError(err).Errorf(`failed to move node to "unknown" state`)
}
}
return nil
})
return err
}
func (d *Dispatcher) isRunning() bool {
if d.ctx == nil {
return false
}
select {
case <-d.ctx.Done():
return false
default:
}
return true
}
// register is used for registration of a node with a particular dispatcher.
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, string, error) {
// prevent register until we're ready to accept it
if err := d.isRunningLocked(); err != nil {
return "", "", err
}
if err := d.nodes.CheckRateLimit(nodeID); err != nil {
return "", "", err
}
// create or update node in store
// TODO(stevvooe): Validate node specification.
var node *api.Node
err := d.store.Update(func(tx store.Tx) error {
node = store.GetNode(tx, nodeID)
if node == nil {
return ErrNodeNotFound
}
node.Description = description
node.Status = api.NodeStatus{
State: api.NodeStatus_READY,
}
return store.UpdateNode(tx, node)
})
if err != nil {
return "", "", err
}
expireFunc := func() {
nodeStatus := api.NodeStatus{State: api.NodeStatus_DOWN, Message: "heartbeat failure"}
log.G(ctx).Debugf("heartbeat expiration")
if err := d.nodeRemove(nodeID, nodeStatus); err != nil {
log.G(ctx).WithError(err).Errorf("failed deregistering node after heartbeat expiration")
}
}
rn := d.nodes.Add(node, expireFunc)
// NOTE(stevvooe): We need be a little careful with re-registration. The
// current implementation just matches the node id and then gives away the
// sessionID. If we ever want to use sessionID as a secret, which we may
// want to, this is giving away the keys to the kitchen.
//
// The right behavior is going to be informed by identity. Basically, each
// time a node registers, we invalidate the session and issue a new
// session, once identity is proven. This will cause misbehaved agents to
// be kicked when multiple connections are made.
return rn.Node.ID, rn.SessionID, nil
}
// UpdateTaskStatus updates the status of a task. A node should send such
// updates on every status change of its tasks.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return nil, err
}
nodeID := nodeInfo.NodeID
fields := logrus.Fields{
"node.id": nodeID,
"node.session": r.SessionID,
"method": "(*Dispatcher).UpdateTaskStatus",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log := log.G(ctx).WithFields(fields)
if err := d.isRunningLocked(); err != nil {
return nil, err
}
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return nil, err
}
// Validate task updates
for _, u := range r.Updates {
if u.Status == nil {
log.WithField("task.id", u.TaskID).Warn("task report has nil status")
continue
}
var t *api.Task
d.store.View(func(tx store.ReadTx) {
t = store.GetTask(tx, u.TaskID)
})
if t == nil {
log.WithField("task.id", u.TaskID).Warn("cannot find target task in store")
continue
}
if t.NodeID != nodeID {
err := grpc.Errorf(codes.PermissionDenied, "cannot update a task not assigned this node")
log.WithField("task.id", u.TaskID).Error(err)
return nil, err
}
}
d.taskUpdatesLock.Lock()
// Enqueue task updates
for _, u := range r.Updates {
if u.Status == nil {
continue
}
d.taskUpdates[u.TaskID] = u.Status
}
numUpdates := len(d.taskUpdates)
d.taskUpdatesLock.Unlock()
if numUpdates >= maxBatchItems {
d.processTaskUpdatesTrigger <- struct{}{}
}
return nil, nil
}
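// Agent-side sketch (illustrative, not part of the original file): reporting a
// task state change through the generated gRPC client. The client constructor,
// the nested Updates element type, and the TaskStateRunning constant are
// assumed to follow the generated api package.
//
//	client := api.NewDispatcherClient(conn)
//	_, err := client.UpdateTaskStatus(ctx, &api.UpdateTaskStatusRequest{
//		SessionID: sessionID,
//		Updates: []*api.UpdateTaskStatusRequest_TaskStatusUpdate{
//			{TaskID: taskID, Status: &api.TaskStatus{State: api.TaskStateRunning}},
//		},
//	})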
func (d *Dispatcher) processTaskUpdates() {
d.taskUpdatesLock.Lock()
if len(d.taskUpdates) == 0 {
d.taskUpdatesLock.Unlock()
return
}
taskUpdates := d.taskUpdates
d.taskUpdates = make(map[string]*api.TaskStatus)
d.taskUpdatesLock.Unlock()
log := log.G(d.ctx).WithFields(logrus.Fields{
"method": "(*Dispatcher).processTaskUpdates",
})
_, err := d.store.Batch(func(batch *store.Batch) error {
for taskID, status := range taskUpdates {
err := batch.Update(func(tx store.Tx) error {
logger := log.WithField("task.id", taskID)
task := store.GetTask(tx, taskID)
if task == nil {
logger.Errorf("task unavailable")
return nil
}
logger = logger.WithField("state.transition", fmt.Sprintf("%v->%v", task.Status.State, status.State))
if task.Status == *status {
logger.Debug("task status identical, ignoring")
return nil
}
if task.Status.State > status.State {
logger.Debug("task status invalid transition")
return nil
}
task.Status = *status
if err := store.UpdateTask(tx, task); err != nil {
logger.WithError(err).Error("failed to update task status")
return nil
}
logger.Debug("task status updated")
return nil
})
if err != nil {
log.WithError(err).Error("dispatcher transaction failed")
}
}
return nil
})
if err != nil {
log.WithError(err).Error("dispatcher batch failed")
}
}
// Tasks is a stream of task states for a node. Each message contains the full
// list of tasks which should be run on the node; if a task is not present in
// that list, it should be terminated.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
nodeInfo, err := ca.RemoteNode(stream.Context())
if err != nil {
return err
}
nodeID := nodeInfo.NodeID
if err := d.isRunningLocked(); err != nil {
return err
}
fields := logrus.Fields{
"node.id": nodeID,
"node.session": r.SessionID,
"method": "(*Dispatcher).Tasks",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log.G(stream.Context()).WithFields(fields).Debugf("")
if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
tasksMap := make(map[string]*api.Task)
nodeTasks, cancel, err := store.ViewAndWatch(
d.store,
func(readTx store.ReadTx) error {
tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
if err != nil {
return err
}
for _, t := range tasks {
tasksMap[t.ID] = t
}
return nil
},
state.EventCreateTask{Task: &api.Task{NodeID: nodeID},
Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
state.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
state.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
Checks: []state.TaskCheckFunc{state.TaskCheckNodeID}},
)
if err != nil {
return err
}
defer cancel()
for {
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
var tasks []*api.Task
for _, t := range tasksMap {
// dispatcher only sends tasks that have been assigned to a node
if t != nil && t.Status.State >= api.TaskStateAssigned {
tasks = append(tasks, t)
}
}
if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
return err
}
select {
case event := <-nodeTasks:
switch v := event.(type) {
case state.EventCreateTask:
tasksMap[v.Task.ID] = v.Task
case state.EventUpdateTask:
tasksMap[v.Task.ID] = v.Task
case state.EventDeleteTask:
delete(tasksMap, v.Task.ID)
}
case <-stream.Context().Done():
return stream.Context().Err()
case <-d.ctx.Done():
return d.ctx.Err()
}
}
}
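// Agent-side sketch (illustrative, assuming the generated DispatcherClient):
// consuming the Tasks stream and reconciling against the full task list on
// every message, terminating anything no longer assigned, as the method
// comment above describes.
//
//	stream, err := client.Tasks(ctx, &api.TasksRequest{SessionID: sessionID})
//	if err != nil {
//		return err
//	}
//	for {
//		msg, err := stream.Recv()
//		if err != nil {
//			return err
//		}
//		assigned := make(map[string]*api.Task, len(msg.Tasks))
//		for _, t := range msg.Tasks {
//			assigned[t.ID] = t
//		}
//		// Reconcile: start or update tasks in `assigned`, stop local tasks
//		// that are absent from it.
//	}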
func (d *Dispatcher) nodeRemove(id string, status api.NodeStatus) error {
if err := d.isRunningLocked(); err != nil {
return err
}
// TODO(aaronl): Is it worth batching node removals?
err := d.store.Update(func(tx store.Tx) error {
node := store.GetNode(tx, id)
if node == nil {
return errors.New("node not found")
}
node.Status = status
return store.UpdateNode(tx, node)
})
if err != nil {
return fmt.Errorf("failed to update node %s status to down: %v", id, err)
}
if rn := d.nodes.Delete(id); rn == nil {
return fmt.Errorf("node %s is not found in local storage", id)
}
return nil
}
// Heartbeat is the heartbeat method for nodes. It returns a new TTL in the
// response. A node should send its next heartbeat earlier than now + TTL,
// otherwise it will be deregistered from the dispatcher and its status will be
// updated to NodeStatus_DOWN.
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return nil, err
}
period, err := d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)
return &api.HeartbeatResponse{Period: *ptypes.DurationProto(period)}, err
}
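// Agent-side sketch (illustrative): heartbeating at the period returned by the
// dispatcher. The generated client and the ptypes.Duration helper used to
// decode the Period field are assumptions of this sketch; an error means the
// session is gone and the node should re-register.
//
//	for {
//		resp, err := client.Heartbeat(ctx, &api.HeartbeatRequest{SessionID: sessionID})
//		if err != nil {
//			return err // session invalid or dispatcher gone; re-register
//		}
//		period, _ := ptypes.Duration(&resp.Period)
//		select {
//		case <-time.After(period):
//		case <-ctx.Done():
//			return ctx.Err()
//		}
//	}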
func (d *Dispatcher) getManagers() []*api.WeightedPeer {
d.mu.Lock()
defer d.mu.Unlock()
return d.lastSeenManagers
}
// Session is a stream which controls the agent's connection. Each message
// contains a list of backup Managers with weights. There is also a special
// boolean field, Disconnect, which if true indicates that the node should
// reconnect to another Manager immediately.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
ctx := stream.Context()
nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return err
}
nodeID := nodeInfo.NodeID
if err := d.isRunningLocked(); err != nil {
return err
}
// register the node.
nodeID, sessionID, err := d.register(stream.Context(), nodeID, r.Description)
if err != nil {
return err
}
fields := logrus.Fields{
"node.id": nodeID,
"node.session": sessionID,
"method": "(*Dispatcher).Session",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log := log.G(ctx).WithFields(fields)
var nodeObj *api.Node
nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
nodeObj = store.GetNode(readTx, nodeID)
return nil
}, state.EventUpdateNode{Node: &api.Node{ID: nodeID},
Checks: []state.NodeCheckFunc{state.NodeCheckID}},
)
if cancel != nil {
defer cancel()
}
if err != nil {
log.WithError(err).Error("ViewAndWatch Node failed")
}
if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
return err
}
if err := stream.Send(&api.SessionMessage{
SessionID: sessionID,
Node: nodeObj,
Managers: d.getManagers(),
NetworkBootstrapKeys: d.networkBootstrapKeys,
}); err != nil {
return err
}
managerUpdates, mgrCancel := d.mgrQueue.Watch()
defer mgrCancel()
keyMgrUpdates, keyMgrCancel := d.keyMgrQueue.Watch()
defer keyMgrCancel()
// disconnectNode is a helper to forcibly shut down the connection.
disconnectNode := func() error {
// force disconnect by shutting down the stream.
transportStream, ok := transport.StreamFromContext(stream.Context())
if ok {
// if we have the transport stream, we can signal a disconnect
// in the client.
if err := transportStream.ServerTransport().Close(); err != nil {
log.WithError(err).Error("session end")
}
}
nodeStatus := api.NodeStatus{State: api.NodeStatus_DISCONNECTED, Message: "node is currently trying to find new manager"}
if err := d.nodeRemove(nodeID, nodeStatus); err != nil {
log.WithError(err).Error("failed to remove node")
}
// still return an abort if the transport closure was ineffective.
return grpc.Errorf(codes.Aborted, "node must disconnect")
}
for {
// After each message send, we need to check that the node's sessionID hasn't
// changed. If it has, we will shut down the stream and make the node
// re-register.
node, err := d.nodes.GetWithSession(nodeID, sessionID)
if err != nil {
return err
}
var mgrs []*api.WeightedPeer
var disconnect bool
select {
case ev := <-managerUpdates:
mgrs = ev.([]*api.WeightedPeer)
case ev := <-nodeUpdates:
nodeObj = ev.(state.EventUpdateNode).Node
case <-stream.Context().Done():
return stream.Context().Err()
case <-node.Disconnect:
disconnect = true
case <-d.ctx.Done():
disconnect = true
case <-keyMgrUpdates:
}
if mgrs == nil {
mgrs = d.getManagers()
}
if err := stream.Send(&api.SessionMessage{
SessionID: sessionID,
Node: nodeObj,
Managers: mgrs,
NetworkBootstrapKeys: d.networkBootstrapKeys,
}); err != nil {
return err
}
if disconnect {
return disconnectNode()
}
}
}
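// Agent-side sketch (illustrative, assuming the generated DispatcherClient):
// opening a session and reacting to the manager-list and bootstrap-key updates
// pushed by the dispatcher. updateManagerList and updateNetworkKeys are
// hypothetical helpers; a stream error means the node should pick another
// manager and re-register.
//
//	session, err := client.Session(ctx, &api.SessionRequest{Description: description})
//	if err != nil {
//		return err
//	}
//	for {
//		msg, err := session.Recv()
//		if err != nil {
//			return err // session ended; reconnect and re-register
//		}
//		sessionID = msg.SessionID
//		updateManagerList(msg.Managers)
//		updateNetworkKeys(msg.NetworkBootstrapKeys)
//	}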
// NodeCount returns the number of nodes connected to this dispatcher.
func (d *Dispatcher) NodeCount() int {
return d.nodes.Len()
}