From 48d387a757570245b4549cc46d96017630baf32f Mon Sep 17 00:00:00 2001 From: Kazuyoshi Kato Date: Mon, 22 Nov 2021 09:36:11 -0800 Subject: [PATCH] daemon/logger: read the length header correctly Before this change, if Decode() couldn't read a log record fully, the subsequent invocation of Decode() would read the record's non-header part as a header and cause a huge heap allocation. This change prevents such a case by having the intermediate buffer in the decoder struct. Fixes #42125. Signed-off-by: Kazuyoshi Kato --- daemon/logger/local/read.go | 121 ++++++++++++++++----------- daemon/logger/local/read_test.go | 52 ++++++++++++ daemon/logger/loggerutils/logfile.go | 1 + 3 files changed, 125 insertions(+), 49 deletions(-) create mode 100644 daemon/logger/local/read_test.go diff --git a/daemon/logger/local/read.go b/daemon/logger/local/read.go index b28aa90651..d517995cff 100644 --- a/daemon/logger/local/read.go +++ b/daemon/logger/local/read.go @@ -1,12 +1,12 @@ package local import ( + "bytes" "context" "encoding/binary" + "fmt" "io" - "bytes" - "github.com/docker/docker/api/types/plugins/logdriver" "github.com/docker/docker/daemon/logger" "github.com/docker/docker/daemon/logger/loggerutils" @@ -14,6 +14,10 @@ import ( "github.com/pkg/errors" ) +// maxMsgLen is the maximum size of the logger.Message after serialization. +// logger.defaultBufSize caps the size of Line field. +const maxMsgLen int = 1e6 // 1MB. + func (d *driver) ReadLogs(config logger.ReadConfig) *logger.LogWatcher { logWatcher := logger.NewLogWatcher() @@ -99,7 +103,35 @@ func getTailReader(ctx context.Context, r loggerutils.SizeReaderAt, req int) (io type decoder struct { rdr io.Reader proto *logdriver.LogEntry - buf []byte + // buf keeps bytes from rdr. + buf []byte + // offset is the position in buf. + // If offset > 0, buf[offset:] has bytes which are read but haven't used. + offset int + // nextMsgLen is the length of the next log message. + // If nextMsgLen = 0, a new value must be read from rdr. + nextMsgLen int +} + +func (d *decoder) readRecord(size int) error { + var err error + for i := 0; i < maxDecodeRetry; i++ { + var n int + n, err = io.ReadFull(d.rdr, d.buf[d.offset:size]) + d.offset += n + if err != nil { + if err != io.ErrUnexpectedEOF { + return err + } + continue + } + break + } + if err != nil { + return err + } + d.offset = 0 + return nil } func (d *decoder) Decode() (*logger.Message, error) { @@ -111,44 +143,35 @@ func (d *decoder) Decode() (*logger.Message, error) { if d.buf == nil { d.buf = make([]byte, initialBufSize) } - var ( - read int - err error - ) - for i := 0; i < maxDecodeRetry; i++ { - var n int - n, err = io.ReadFull(d.rdr, d.buf[read:encodeBinaryLen]) + if d.nextMsgLen == 0 { + msgLen, err := d.decodeSizeHeader() if err != nil { - if err != io.ErrUnexpectedEOF { - return nil, errors.Wrap(err, "error reading log message length") - } - read += n - continue + return nil, err } - read += n - break - } - if err != nil { - return nil, errors.Wrapf(err, "could not read log message length: read: %d, expected: %d", read, encodeBinaryLen) - } - msgLen := int(binary.BigEndian.Uint32(d.buf[:read])) + if msgLen > maxMsgLen { + return nil, fmt.Errorf("log message is too large (%d > %d)", msgLen, maxMsgLen) + } - if len(d.buf) < msgLen+encodeBinaryLen { - d.buf = make([]byte, msgLen+encodeBinaryLen) - } else { - if msgLen <= initialBufSize { + if len(d.buf) < msgLen+encodeBinaryLen { + d.buf = make([]byte, msgLen+encodeBinaryLen) + } else if msgLen <= initialBufSize { d.buf = d.buf[:initialBufSize] } else { d.buf = d.buf[:msgLen+encodeBinaryLen] } - } - return decodeLogEntry(d.rdr, d.proto, d.buf, msgLen) + d.nextMsgLen = msgLen + } + return d.decodeLogEntry() } func (d *decoder) Reset(rdr io.Reader) { + if d.rdr == rdr { + return + } + d.rdr = rdr if d.proto != nil { resetProto(d.proto) @@ -156,6 +179,8 @@ func (d *decoder) Reset(rdr io.Reader) { if d.buf != nil { d.buf = d.buf[:initialBufSize] } + d.offset = 0 + d.nextMsgLen = 0 } func (d *decoder) Close() { @@ -171,34 +196,32 @@ func decodeFunc(rdr io.Reader) loggerutils.Decoder { return &decoder{rdr: rdr} } -func decodeLogEntry(rdr io.Reader, proto *logdriver.LogEntry, buf []byte, msgLen int) (*logger.Message, error) { - var ( - read int - err error - ) - for i := 0; i < maxDecodeRetry; i++ { - var n int - n, err = io.ReadFull(rdr, buf[read:msgLen+encodeBinaryLen]) - if err != nil { - if err != io.ErrUnexpectedEOF { - return nil, errors.Wrap(err, "could not decode log entry") - } - read += n - continue - } - break - } +func (d *decoder) decodeSizeHeader() (int, error) { + err := d.readRecord(encodeBinaryLen) if err != nil { - return nil, errors.Wrapf(err, "could not decode entry: read %d, expected: %d", read, msgLen) + return 0, errors.Wrap(err, "could not read a size header") } - if err := proto.Unmarshal(buf[:msgLen]); err != nil { - return nil, errors.Wrap(err, "error unmarshalling log entry") + msgLen := int(binary.BigEndian.Uint32(d.buf[:encodeBinaryLen])) + return msgLen, nil +} + +func (d *decoder) decodeLogEntry() (*logger.Message, error) { + msgLen := d.nextMsgLen + err := d.readRecord(msgLen + encodeBinaryLen) + if err != nil { + return nil, errors.Wrapf(err, "could not read a log entry (size=%d+%d)", msgLen, encodeBinaryLen) + } + d.nextMsgLen = 0 + + if err := d.proto.Unmarshal(d.buf[:msgLen]); err != nil { + return nil, errors.Wrapf(err, "error unmarshalling log entry (size=%d)", msgLen) } - msg := protoToMessage(proto) + msg := protoToMessage(d.proto) if msg.PLogMetaData == nil { msg.Line = append(msg.Line, '\n') } + return msg, nil } diff --git a/daemon/logger/local/read_test.go b/daemon/logger/local/read_test.go new file mode 100644 index 0000000000..21d8603649 --- /dev/null +++ b/daemon/logger/local/read_test.go @@ -0,0 +1,52 @@ +package local + +import ( + "io" + "io/ioutil" + "os" + "testing" + + "github.com/docker/docker/daemon/logger" + "github.com/pkg/errors" + "gotest.tools/v3/assert" +) + +func TestDecode(t *testing.T) { + marshal := makeMarshaller() + + buf, err := marshal(&logger.Message{Line: []byte("hello")}) + assert.NilError(t, err) + + for i := 0; i < len(buf); i++ { + testDecode(t, buf, i) + } +} + +func testDecode(t *testing.T, buf []byte, split int) { + fw, err := ioutil.TempFile("", t.Name()) + assert.NilError(t, err) + defer os.Remove(fw.Name()) + + fr, err := os.Open(fw.Name()) + assert.NilError(t, err) + + d := &decoder{rdr: fr} + + if split > 0 { + _, err = fw.Write(buf[0:split]) + assert.NilError(t, err) + + _, err = d.Decode() + assert.Assert(t, errors.Is(err, io.EOF)) + + _, err = fw.Write(buf[split:]) + assert.NilError(t, err) + } else { + _, err = fw.Write(buf) + assert.NilError(t, err) + } + + message, err := d.Decode() + assert.NilError(t, err) + assert.Equal(t, "hello\n", string(message.Line)) +} diff --git a/daemon/logger/loggerutils/logfile.go b/daemon/logger/loggerutils/logfile.go index 6b42d9dd30..8817934946 100644 --- a/daemon/logger/loggerutils/logfile.go +++ b/daemon/logger/loggerutils/logfile.go @@ -715,6 +715,7 @@ func followLogs(f *os.File, logWatcher *logger.LogWatcher, notifyRotate, notifyE defer func() { oldSize = size }() if size < oldSize { // truncated f.Seek(0, 0) + dec.Reset(f) return nil } } else {