1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

daemon/logger: read the length header correctly

Before this change, if Decode() couldn't read a log record fully,
the subsequent invocation of Decode() would read the record's non-header part
as a header and cause a huge heap allocation.

This change prevents such a case by having the intermediate buffer in
the decoder struct.

Fixes #42125.

Signed-off-by: Kazuyoshi Kato <katokazu@amazon.com>
This commit is contained in:
Kazuyoshi Kato 2021-11-22 09:36:11 -08:00
parent 8955d8da89
commit 48d387a757
3 changed files with 125 additions and 49 deletions

View file

@ -1,12 +1,12 @@
package local
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"bytes"
"github.com/docker/docker/api/types/plugins/logdriver"
"github.com/docker/docker/daemon/logger"
"github.com/docker/docker/daemon/logger/loggerutils"
@ -14,6 +14,10 @@ import (
"github.com/pkg/errors"
)
// maxMsgLen is the maximum size of the logger.Message after serialization.
// logger.defaultBufSize caps the size of Line field.
const maxMsgLen int = 1e6 // 1MB.
func (d *driver) ReadLogs(config logger.ReadConfig) *logger.LogWatcher {
logWatcher := logger.NewLogWatcher()
@ -99,7 +103,35 @@ func getTailReader(ctx context.Context, r loggerutils.SizeReaderAt, req int) (io
type decoder struct {
rdr io.Reader
proto *logdriver.LogEntry
buf []byte
// buf keeps bytes from rdr.
buf []byte
// offset is the position in buf.
// If offset > 0, buf[offset:] has bytes which are read but haven't used.
offset int
// nextMsgLen is the length of the next log message.
// If nextMsgLen = 0, a new value must be read from rdr.
nextMsgLen int
}
func (d *decoder) readRecord(size int) error {
var err error
for i := 0; i < maxDecodeRetry; i++ {
var n int
n, err = io.ReadFull(d.rdr, d.buf[d.offset:size])
d.offset += n
if err != nil {
if err != io.ErrUnexpectedEOF {
return err
}
continue
}
break
}
if err != nil {
return err
}
d.offset = 0
return nil
}
func (d *decoder) Decode() (*logger.Message, error) {
@ -111,44 +143,35 @@ func (d *decoder) Decode() (*logger.Message, error) {
if d.buf == nil {
d.buf = make([]byte, initialBufSize)
}
var (
read int
err error
)
for i := 0; i < maxDecodeRetry; i++ {
var n int
n, err = io.ReadFull(d.rdr, d.buf[read:encodeBinaryLen])
if d.nextMsgLen == 0 {
msgLen, err := d.decodeSizeHeader()
if err != nil {
if err != io.ErrUnexpectedEOF {
return nil, errors.Wrap(err, "error reading log message length")
}
read += n
continue
return nil, err
}
read += n
break
}
if err != nil {
return nil, errors.Wrapf(err, "could not read log message length: read: %d, expected: %d", read, encodeBinaryLen)
}
msgLen := int(binary.BigEndian.Uint32(d.buf[:read]))
if msgLen > maxMsgLen {
return nil, fmt.Errorf("log message is too large (%d > %d)", msgLen, maxMsgLen)
}
if len(d.buf) < msgLen+encodeBinaryLen {
d.buf = make([]byte, msgLen+encodeBinaryLen)
} else {
if msgLen <= initialBufSize {
if len(d.buf) < msgLen+encodeBinaryLen {
d.buf = make([]byte, msgLen+encodeBinaryLen)
} else if msgLen <= initialBufSize {
d.buf = d.buf[:initialBufSize]
} else {
d.buf = d.buf[:msgLen+encodeBinaryLen]
}
}
return decodeLogEntry(d.rdr, d.proto, d.buf, msgLen)
d.nextMsgLen = msgLen
}
return d.decodeLogEntry()
}
func (d *decoder) Reset(rdr io.Reader) {
if d.rdr == rdr {
return
}
d.rdr = rdr
if d.proto != nil {
resetProto(d.proto)
@ -156,6 +179,8 @@ func (d *decoder) Reset(rdr io.Reader) {
if d.buf != nil {
d.buf = d.buf[:initialBufSize]
}
d.offset = 0
d.nextMsgLen = 0
}
func (d *decoder) Close() {
@ -171,34 +196,32 @@ func decodeFunc(rdr io.Reader) loggerutils.Decoder {
return &decoder{rdr: rdr}
}
func decodeLogEntry(rdr io.Reader, proto *logdriver.LogEntry, buf []byte, msgLen int) (*logger.Message, error) {
var (
read int
err error
)
for i := 0; i < maxDecodeRetry; i++ {
var n int
n, err = io.ReadFull(rdr, buf[read:msgLen+encodeBinaryLen])
if err != nil {
if err != io.ErrUnexpectedEOF {
return nil, errors.Wrap(err, "could not decode log entry")
}
read += n
continue
}
break
}
func (d *decoder) decodeSizeHeader() (int, error) {
err := d.readRecord(encodeBinaryLen)
if err != nil {
return nil, errors.Wrapf(err, "could not decode entry: read %d, expected: %d", read, msgLen)
return 0, errors.Wrap(err, "could not read a size header")
}
if err := proto.Unmarshal(buf[:msgLen]); err != nil {
return nil, errors.Wrap(err, "error unmarshalling log entry")
msgLen := int(binary.BigEndian.Uint32(d.buf[:encodeBinaryLen]))
return msgLen, nil
}
func (d *decoder) decodeLogEntry() (*logger.Message, error) {
msgLen := d.nextMsgLen
err := d.readRecord(msgLen + encodeBinaryLen)
if err != nil {
return nil, errors.Wrapf(err, "could not read a log entry (size=%d+%d)", msgLen, encodeBinaryLen)
}
d.nextMsgLen = 0
if err := d.proto.Unmarshal(d.buf[:msgLen]); err != nil {
return nil, errors.Wrapf(err, "error unmarshalling log entry (size=%d)", msgLen)
}
msg := protoToMessage(proto)
msg := protoToMessage(d.proto)
if msg.PLogMetaData == nil {
msg.Line = append(msg.Line, '\n')
}
return msg, nil
}

View file

@ -0,0 +1,52 @@
package local
import (
"io"
"io/ioutil"
"os"
"testing"
"github.com/docker/docker/daemon/logger"
"github.com/pkg/errors"
"gotest.tools/v3/assert"
)
func TestDecode(t *testing.T) {
marshal := makeMarshaller()
buf, err := marshal(&logger.Message{Line: []byte("hello")})
assert.NilError(t, err)
for i := 0; i < len(buf); i++ {
testDecode(t, buf, i)
}
}
func testDecode(t *testing.T, buf []byte, split int) {
fw, err := ioutil.TempFile("", t.Name())
assert.NilError(t, err)
defer os.Remove(fw.Name())
fr, err := os.Open(fw.Name())
assert.NilError(t, err)
d := &decoder{rdr: fr}
if split > 0 {
_, err = fw.Write(buf[0:split])
assert.NilError(t, err)
_, err = d.Decode()
assert.Assert(t, errors.Is(err, io.EOF))
_, err = fw.Write(buf[split:])
assert.NilError(t, err)
} else {
_, err = fw.Write(buf)
assert.NilError(t, err)
}
message, err := d.Decode()
assert.NilError(t, err)
assert.Equal(t, "hello\n", string(message.Line))
}

View file

@ -715,6 +715,7 @@ func followLogs(f *os.File, logWatcher *logger.LogWatcher, notifyRotate, notifyE
defer func() { oldSize = size }()
if size < oldSize { // truncated
f.Seek(0, 0)
dec.Reset(f)
return nil
}
} else {