package ttrpc

import (
	"bufio"
	"context"
	"encoding/binary"
	"io"
	"sync"

	"github.com/pkg/errors"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

const (
	// messageHeaderLength is the size in bytes of the fixed header that
	// precedes every message on the wire.
	messageHeaderLength = 10

	// messageLengthMax is the largest message payload, in bytes (4 MiB),
	// that recv accepts; longer payloads are discarded.
	messageLengthMax = 4 << 20
)

// messageType is the value carried in the header's type byte.
type messageType uint8

const (
	messageTypeRequest  messageType = 0x1
	messageTypeResponse messageType = 0x2
)

// messageHeader represents the fixed-length message header of 10 bytes sent
// with every request.
type messageHeader struct {
	Length   uint32      // length excluding this header. b[:4]
	StreamID uint32      // identifies which request stream message is a part of. b[4:8]
	Type     messageType // message type b[8]
	Flags    uint8       // reserved b[9]
}

// readMessageHeader reads exactly messageHeaderLength bytes from r into the
// scratch buffer p and decodes them as a big-endian messageHeader.
//
// p must be at least messageHeaderLength bytes long. Any error from
// io.ReadFull (including io.EOF and io.ErrUnexpectedEOF) is returned
// unmodified together with a zero header.
func readMessageHeader(p []byte, r io.Reader) (messageHeader, error) {
	if _, err := io.ReadFull(r, p[:messageHeaderLength]); err != nil {
		return messageHeader{}, err
	}

	return messageHeader{
		Length:   binary.BigEndian.Uint32(p[:4]),
		StreamID: binary.BigEndian.Uint32(p[4:8]),
		Type:     messageType(p[8]),
		Flags:    p[9],
	}, nil
}

// writeMessageHeader encodes mh in big-endian form into the scratch buffer p
// and writes the encoded header to w.
//
// p must be at least messageHeaderLength bytes long. Only the first
// messageHeaderLength bytes are written, mirroring readMessageHeader, so a
// larger scratch buffer never leaks trailing bytes into the stream.
func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
	hdr := p[:messageHeaderLength]

	binary.BigEndian.PutUint32(hdr[:4], mh.Length)
	binary.BigEndian.PutUint32(hdr[4:8], mh.StreamID)
	hdr[8] = byte(mh.Type)
	hdr[9] = mh.Flags

	_, err := w.Write(hdr)
	return err
}

// buffers pools message payload buffers. Entries are stored as *[]byte; see
// getmbuf and putmbuf.
var buffers sync.Pool

// channel frames messages over a buffered reader/writer pair.
//
// NOTE(review): the header scratch buffers are per-channel and unsynchronized,
// so concurrent recv calls (or concurrent send calls) presumably need to be
// serialized by the caller — confirm against the call sites.
type channel struct {
	bw    *bufio.Writer
	br    *bufio.Reader
	hrbuf [messageHeaderLength]byte // avoid alloc when reading header
	hwbuf [messageHeaderLength]byte // avoid alloc when writing header
}

// newChannel wraps w and r in buffered I/O and returns a channel that sends
// on w and receives from r.
func newChannel(w io.Writer, r io.Reader) *channel {
	return &channel{
		bw: bufio.NewWriter(w),
		br: bufio.NewReader(r),
	}
}

// recv a message from the channel. The returned buffer contains the message.
//
// If a valid grpc status is returned, the message header
// returned will be valid and caller should send that along to
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) { mh, err := readMessageHeader(ch.hrbuf[:], ch.br) if err != nil { return messageHeader{}, nil, err } if mh.Length > uint32(messageLengthMax) { if _, err := ch.br.Discard(int(mh.Length)); err != nil { return mh, nil, errors.Wrapf(err, "failed to discard after receiving oversized message") } return mh, nil, status.Errorf(codes.ResourceExhausted, "message length %v exceed maximum message size of %v", mh.Length, messageLengthMax) } p := ch.getmbuf(int(mh.Length)) if _, err := io.ReadFull(ch.br, p); err != nil { return messageHeader{}, nil, errors.Wrapf(err, "failed reading message") } return mh, p, nil } func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error { if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil { return err } _, err := ch.bw.Write(p) if err != nil { return err } return ch.bw.Flush() } func (ch *channel) getmbuf(size int) []byte { // we can't use the standard New method on pool because we want to allocate // based on size. b, ok := buffers.Get().(*[]byte) if !ok || cap(*b) < size { // TODO(stevvooe): It may be better to allocate these in fixed length // buckets to reduce fragmentation but its not clear that would help // with performance. An ilogb approach or similar would work well. bb := make([]byte, size) b = &bb } else { *b = (*b)[:size] } return *b } func (ch *channel) putmbuf(p []byte) { buffers.Put(&p) }