mirror of
https://github.com/moby/moby.git
synced 2022-11-09 12:21:53 -05:00
daemon/logger: read the length header correctly
Before this change, if Decode() couldn't read a log record fully, the subsequent invocation of Decode() would read the record's non-header part as a header and cause a huge heap allocation. This change prevents such a case by having the intermediate buffer in the decoder struct. Fixes #42125. Signed-off-by: Kazuyoshi Kato <katokazu@amazon.com>
This commit is contained in:
parent
8955d8da89
commit
48d387a757
3 changed files with 125 additions and 49 deletions
|
@ -1,12 +1,12 @@
|
|||
package local
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"bytes"
|
||||
|
||||
"github.com/docker/docker/api/types/plugins/logdriver"
|
||||
"github.com/docker/docker/daemon/logger"
|
||||
"github.com/docker/docker/daemon/logger/loggerutils"
|
||||
|
@ -14,6 +14,10 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// maxMsgLen is the maximum size of the logger.Message after serialization.
|
||||
// logger.defaultBufSize caps the size of Line field.
|
||||
const maxMsgLen int = 1e6 // 1MB.
|
||||
|
||||
func (d *driver) ReadLogs(config logger.ReadConfig) *logger.LogWatcher {
|
||||
logWatcher := logger.NewLogWatcher()
|
||||
|
||||
|
@ -99,7 +103,35 @@ func getTailReader(ctx context.Context, r loggerutils.SizeReaderAt, req int) (io
|
|||
type decoder struct {
|
||||
rdr io.Reader
|
||||
proto *logdriver.LogEntry
|
||||
buf []byte
|
||||
// buf keeps bytes from rdr.
|
||||
buf []byte
|
||||
// offset is the position in buf.
|
||||
// If offset > 0, buf[offset:] has bytes which are read but haven't used.
|
||||
offset int
|
||||
// nextMsgLen is the length of the next log message.
|
||||
// If nextMsgLen = 0, a new value must be read from rdr.
|
||||
nextMsgLen int
|
||||
}
|
||||
|
||||
func (d *decoder) readRecord(size int) error {
|
||||
var err error
|
||||
for i := 0; i < maxDecodeRetry; i++ {
|
||||
var n int
|
||||
n, err = io.ReadFull(d.rdr, d.buf[d.offset:size])
|
||||
d.offset += n
|
||||
if err != nil {
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.offset = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *decoder) Decode() (*logger.Message, error) {
|
||||
|
@ -111,44 +143,35 @@ func (d *decoder) Decode() (*logger.Message, error) {
|
|||
if d.buf == nil {
|
||||
d.buf = make([]byte, initialBufSize)
|
||||
}
|
||||
var (
|
||||
read int
|
||||
err error
|
||||
)
|
||||
|
||||
for i := 0; i < maxDecodeRetry; i++ {
|
||||
var n int
|
||||
n, err = io.ReadFull(d.rdr, d.buf[read:encodeBinaryLen])
|
||||
if d.nextMsgLen == 0 {
|
||||
msgLen, err := d.decodeSizeHeader()
|
||||
if err != nil {
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
return nil, errors.Wrap(err, "error reading log message length")
|
||||
}
|
||||
read += n
|
||||
continue
|
||||
return nil, err
|
||||
}
|
||||
read += n
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not read log message length: read: %d, expected: %d", read, encodeBinaryLen)
|
||||
}
|
||||
|
||||
msgLen := int(binary.BigEndian.Uint32(d.buf[:read]))
|
||||
if msgLen > maxMsgLen {
|
||||
return nil, fmt.Errorf("log message is too large (%d > %d)", msgLen, maxMsgLen)
|
||||
}
|
||||
|
||||
if len(d.buf) < msgLen+encodeBinaryLen {
|
||||
d.buf = make([]byte, msgLen+encodeBinaryLen)
|
||||
} else {
|
||||
if msgLen <= initialBufSize {
|
||||
if len(d.buf) < msgLen+encodeBinaryLen {
|
||||
d.buf = make([]byte, msgLen+encodeBinaryLen)
|
||||
} else if msgLen <= initialBufSize {
|
||||
d.buf = d.buf[:initialBufSize]
|
||||
} else {
|
||||
d.buf = d.buf[:msgLen+encodeBinaryLen]
|
||||
}
|
||||
}
|
||||
|
||||
return decodeLogEntry(d.rdr, d.proto, d.buf, msgLen)
|
||||
d.nextMsgLen = msgLen
|
||||
}
|
||||
return d.decodeLogEntry()
|
||||
}
|
||||
|
||||
func (d *decoder) Reset(rdr io.Reader) {
|
||||
if d.rdr == rdr {
|
||||
return
|
||||
}
|
||||
|
||||
d.rdr = rdr
|
||||
if d.proto != nil {
|
||||
resetProto(d.proto)
|
||||
|
@ -156,6 +179,8 @@ func (d *decoder) Reset(rdr io.Reader) {
|
|||
if d.buf != nil {
|
||||
d.buf = d.buf[:initialBufSize]
|
||||
}
|
||||
d.offset = 0
|
||||
d.nextMsgLen = 0
|
||||
}
|
||||
|
||||
func (d *decoder) Close() {
|
||||
|
@ -171,34 +196,32 @@ func decodeFunc(rdr io.Reader) loggerutils.Decoder {
|
|||
return &decoder{rdr: rdr}
|
||||
}
|
||||
|
||||
func decodeLogEntry(rdr io.Reader, proto *logdriver.LogEntry, buf []byte, msgLen int) (*logger.Message, error) {
|
||||
var (
|
||||
read int
|
||||
err error
|
||||
)
|
||||
for i := 0; i < maxDecodeRetry; i++ {
|
||||
var n int
|
||||
n, err = io.ReadFull(rdr, buf[read:msgLen+encodeBinaryLen])
|
||||
if err != nil {
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
return nil, errors.Wrap(err, "could not decode log entry")
|
||||
}
|
||||
read += n
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
func (d *decoder) decodeSizeHeader() (int, error) {
|
||||
err := d.readRecord(encodeBinaryLen)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not decode entry: read %d, expected: %d", read, msgLen)
|
||||
return 0, errors.Wrap(err, "could not read a size header")
|
||||
}
|
||||
|
||||
if err := proto.Unmarshal(buf[:msgLen]); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshalling log entry")
|
||||
msgLen := int(binary.BigEndian.Uint32(d.buf[:encodeBinaryLen]))
|
||||
return msgLen, nil
|
||||
}
|
||||
|
||||
func (d *decoder) decodeLogEntry() (*logger.Message, error) {
|
||||
msgLen := d.nextMsgLen
|
||||
err := d.readRecord(msgLen + encodeBinaryLen)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not read a log entry (size=%d+%d)", msgLen, encodeBinaryLen)
|
||||
}
|
||||
d.nextMsgLen = 0
|
||||
|
||||
if err := d.proto.Unmarshal(d.buf[:msgLen]); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshalling log entry (size=%d)", msgLen)
|
||||
}
|
||||
|
||||
msg := protoToMessage(proto)
|
||||
msg := protoToMessage(d.proto)
|
||||
if msg.PLogMetaData == nil {
|
||||
msg.Line = append(msg.Line, '\n')
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
|
52
daemon/logger/local/read_test.go
Normal file
52
daemon/logger/local/read_test.go
Normal file
|
@ -0,0 +1,52 @@
|
|||
package local
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/docker/docker/daemon/logger"
|
||||
"github.com/pkg/errors"
|
||||
"gotest.tools/v3/assert"
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
marshal := makeMarshaller()
|
||||
|
||||
buf, err := marshal(&logger.Message{Line: []byte("hello")})
|
||||
assert.NilError(t, err)
|
||||
|
||||
for i := 0; i < len(buf); i++ {
|
||||
testDecode(t, buf, i)
|
||||
}
|
||||
}
|
||||
|
||||
func testDecode(t *testing.T, buf []byte, split int) {
|
||||
fw, err := ioutil.TempFile("", t.Name())
|
||||
assert.NilError(t, err)
|
||||
defer os.Remove(fw.Name())
|
||||
|
||||
fr, err := os.Open(fw.Name())
|
||||
assert.NilError(t, err)
|
||||
|
||||
d := &decoder{rdr: fr}
|
||||
|
||||
if split > 0 {
|
||||
_, err = fw.Write(buf[0:split])
|
||||
assert.NilError(t, err)
|
||||
|
||||
_, err = d.Decode()
|
||||
assert.Assert(t, errors.Is(err, io.EOF))
|
||||
|
||||
_, err = fw.Write(buf[split:])
|
||||
assert.NilError(t, err)
|
||||
} else {
|
||||
_, err = fw.Write(buf)
|
||||
assert.NilError(t, err)
|
||||
}
|
||||
|
||||
message, err := d.Decode()
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, "hello\n", string(message.Line))
|
||||
}
|
|
@ -715,6 +715,7 @@ func followLogs(f *os.File, logWatcher *logger.LogWatcher, notifyRotate, notifyE
|
|||
defer func() { oldSize = size }()
|
||||
if size < oldSize { // truncated
|
||||
f.Seek(0, 0)
|
||||
dec.Reset(f)
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
|
|
Loading…
Add table
Reference in a new issue