1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

Set timeout on splunk batch send

Before this change, if the splunk endpoint is blocked it will cause a
deadlock on `Close()`.
This sets a reasonable timeout for the http request to send a log batch.

Signed-off-by: Brian Goff <cpuguy83@gmail.com>
This commit is contained in:
Brian Goff 2017-11-14 10:15:38 -05:00
parent 0defc69813
commit 24087399d9
3 changed files with 87 additions and 6 deletions

View file

@ -5,6 +5,7 @@ package splunk
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
@ -63,6 +64,8 @@ const (
envVarStreamChannelSize = "SPLUNK_LOGGING_DRIVER_CHANNEL_SIZE" envVarStreamChannelSize = "SPLUNK_LOGGING_DRIVER_CHANNEL_SIZE"
) )
var batchSendTimeout = 30 * time.Second
type splunkLoggerInterface interface { type splunkLoggerInterface interface {
logger.Logger logger.Logger
worker() worker()
@ -416,13 +419,18 @@ func (l *splunkLogger) worker() {
func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool) []*splunkMessage { func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool) []*splunkMessage {
messagesLen := len(messages) messagesLen := len(messages)
ctx, cancel := context.WithTimeout(context.Background(), batchSendTimeout)
defer cancel()
for i := 0; i < messagesLen; i += l.postMessagesBatchSize { for i := 0; i < messagesLen; i += l.postMessagesBatchSize {
upperBound := i + l.postMessagesBatchSize upperBound := i + l.postMessagesBatchSize
if upperBound > messagesLen { if upperBound > messagesLen {
upperBound = messagesLen upperBound = messagesLen
} }
if err := l.tryPostMessages(messages[i:upperBound]); err != nil {
logrus.Error(err) if err := l.tryPostMessages(ctx, messages[i:upperBound]); err != nil {
logrus.WithError(err).WithField("module", "logger/splunk").Warn("Error while sending logs")
if messagesLen-i >= l.bufferMaximum || lastChance { if messagesLen-i >= l.bufferMaximum || lastChance {
// If this is last chance - print them all to the daemon log // If this is last chance - print them all to the daemon log
if lastChance { if lastChance {
@ -447,7 +455,7 @@ func (l *splunkLogger) postMessages(messages []*splunkMessage, lastChance bool)
return messages[:0] return messages[:0]
} }
func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error { func (l *splunkLogger) tryPostMessages(ctx context.Context, messages []*splunkMessage) error {
if len(messages) == 0 { if len(messages) == 0 {
return nil return nil
} }
@ -486,6 +494,7 @@ func (l *splunkLogger) tryPostMessages(messages []*splunkMessage) error {
if err != nil { if err != nil {
return err return err
} }
req = req.WithContext(ctx)
req.Header.Set("Authorization", l.auth) req.Header.Set("Authorization", l.auth)
// Tell if we are sending gzip compressed body // Tell if we are sending gzip compressed body
if l.gzipCompression { if l.gzipCompression {

View file

@ -2,8 +2,10 @@ package splunk
import ( import (
"compress/gzip" "compress/gzip"
"context"
"fmt" "fmt"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
@ -1062,7 +1064,7 @@ func TestSkipVerify(t *testing.T) {
t.Fatal("No messages should be accepted at this point") t.Fatal("No messages should be accepted at this point")
} }
hec.simulateServerError = false hec.simulateErr(false)
for i := defaultStreamChannelSize * 2; i < defaultStreamChannelSize*4; i++ { for i := defaultStreamChannelSize * 2; i < defaultStreamChannelSize*4; i++ {
if err := loggerDriver.Log(&logger.Message{Line: []byte(fmt.Sprintf("%d", i)), Source: "stdout", Timestamp: time.Now()}); err != nil { if err := loggerDriver.Log(&logger.Message{Line: []byte(fmt.Sprintf("%d", i)), Source: "stdout", Timestamp: time.Now()}); err != nil {
@ -1110,7 +1112,7 @@ func TestBufferMaximum(t *testing.T) {
} }
hec := NewHTTPEventCollectorMock(t) hec := NewHTTPEventCollectorMock(t)
hec.simulateServerError = true hec.simulateErr(true)
go hec.Serve() go hec.Serve()
info := logger.Info{ info := logger.Info{
@ -1308,3 +1310,48 @@ func TestCannotSendAfterClose(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
func TestDeadlockOnBlockedEndpoint(t *testing.T) {
hec := NewHTTPEventCollectorMock(t)
go hec.Serve()
info := logger.Info{
Config: map[string]string{
splunkURLKey: hec.URL(),
splunkTokenKey: hec.token,
},
ContainerID: "containeriid",
ContainerName: "/container_name",
ContainerImageID: "contaimageid",
ContainerImageName: "container_image_name",
}
l, err := New(info)
if err != nil {
t.Fatal(err)
}
ctx, unblock := context.WithCancel(context.Background())
hec.withBlock(ctx)
defer unblock()
batchSendTimeout = 1 * time.Second
if err := l.Log(&logger.Message{}); err != nil {
t.Fatal(err)
}
done := make(chan struct{})
go func() {
l.Close()
close(done)
}()
select {
case <-time.After(60 * time.Second):
buf := make([]byte, 1e6)
buf = buf[:runtime.Stack(buf, true)]
t.Logf("STACK DUMP: \n\n%s\n\n", string(buf))
t.Fatal("timeout waiting for close to finish")
case <-done:
}
}

View file

@ -2,12 +2,14 @@ package splunk
import ( import (
"compress/gzip" "compress/gzip"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"sync"
"testing" "testing"
) )
@ -29,8 +31,10 @@ type HTTPEventCollectorMock struct {
tcpAddr *net.TCPAddr tcpAddr *net.TCPAddr
tcpListener *net.TCPListener tcpListener *net.TCPListener
mu sync.Mutex
token string token string
simulateServerError bool simulateServerError bool
blockingCtx context.Context
test *testing.T test *testing.T
@ -55,6 +59,18 @@ func NewHTTPEventCollectorMock(t *testing.T) *HTTPEventCollectorMock {
connectionVerified: false} connectionVerified: false}
} }
func (hec *HTTPEventCollectorMock) simulateErr(b bool) {
hec.mu.Lock()
hec.simulateServerError = b
hec.mu.Unlock()
}
func (hec *HTTPEventCollectorMock) withBlock(ctx context.Context) {
hec.mu.Lock()
hec.blockingCtx = ctx
hec.mu.Unlock()
}
func (hec *HTTPEventCollectorMock) URL() string { func (hec *HTTPEventCollectorMock) URL() string {
return "http://" + hec.tcpListener.Addr().String() return "http://" + hec.tcpListener.Addr().String()
} }
@ -72,7 +88,16 @@ func (hec *HTTPEventCollectorMock) ServeHTTP(writer http.ResponseWriter, request
hec.numOfRequests++ hec.numOfRequests++
if hec.simulateServerError { hec.mu.Lock()
simErr := hec.simulateServerError
ctx := hec.blockingCtx
hec.mu.Unlock()
if ctx != nil {
<-hec.blockingCtx.Done()
}
if simErr {
if request.Body != nil { if request.Body != nil {
defer request.Body.Close() defer request.Body.Close()
} }