package ttrpc

import (
	"context"
	"net"
	"sync"

	"github.com/containerd/containerd/log"
	"github.com/gogo/protobuf/proto"
	"github.com/pkg/errors"
	"google.golang.org/grpc/status"
)

// Client issues request/response calls over a single net.Conn. All traffic
// is serialized through one goroutine (see run) started by NewClient.
type Client struct {
	codec   codec
	conn    net.Conn
	channel *channel

	// calls hands outgoing requests from dispatch to the run loop.
	calls chan *callRequest

	// closed is closed by Close (at most once, guarded by closeOnce) to ask
	// the run loop to stop.
	closed    chan struct{}
	closeOnce sync.Once

	// done is closed by the run loop when it exits. err is assigned by the
	// run loop before done is closed and must only be read after done has
	// been observed closed.
	done chan struct{}
	err  error
}

// NewClient wraps conn in a Client and starts the goroutine that services
// it. The connection is closed when that goroutine exits.
func NewClient(conn net.Conn) *Client {
	c := &Client{
		codec:   codec{},
		conn:    conn,
		channel: newChannel(conn, conn),
		calls:   make(chan *callRequest),
		closed:  make(chan struct{}),
		done:    make(chan struct{}),
	}

	go c.run()
	return c
}

// callRequest carries one in-flight call between dispatch and the run loop.
type callRequest struct {
	// ctx is the caller's context; it is handed to send when the request is
	// written.
	ctx context.Context
	req *Request
	// resp is where the run loop decodes the response.
	resp *Response
	// errs receives exactly one value when the call completes. It must be
	// buffered so the run loop never blocks on a caller that has given up.
	errs chan error
}

// Call marshals req, sends it to service/method, and decodes the response
// payload into resp. It returns a marshalling or transport error, an error if
// the response carries no status, or otherwise the error derived from the
// response status via status.ErrorProto.
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
	payload, err := c.codec.Marshal(req)
	if err != nil {
		return err
	}

	var (
		creq = &Request{
			Service: service,
			Method:  method,
			Payload: payload,
		}

		cresp = &Response{}
	)

	if err := c.dispatch(ctx, creq, cresp); err != nil {
		return err
	}

	if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
		return err
	}

	if cresp.Status == nil {
		return errors.New("no status provided on response")
	}

	return status.ErrorProto(cresp.Status)
}

// dispatch queues the request on the run loop and waits for its completion.
// It returns early with ctx.Err() if ctx is done, or with the client's
// terminal error if the run loop has exited.
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
	if ctx == nil {
		// Tolerate legacy callers that pass a nil context.
		ctx = context.Background()
	}

	errs := make(chan error, 1)
	call := &callRequest{
		ctx:  ctx,
		req:  req,
		resp: resp,
		errs: errs,
	}

	select {
	case c.calls <- call:
	case <-ctx.Done():
		return ctx.Err()
	case <-c.done:
		return c.err
	}

	// If ctx ends first the call stays registered in the run loop; errs is
	// buffered, so the eventual completion is delivered without blocking and
	// simply goes unread.
	select {
	case err := <-errs:
		return err
	case <-ctx.Done():
		return ctx.Err()
	case <-c.done:
		return c.err
	}
}

// Close asks the run loop to stop. It is safe to call more than once, does
// not wait for the loop to exit, and always returns nil.
func (c *Client) Close() error {
	c.closeOnce.Do(func() {
		close(c.closed)
	})

	return nil
}

// message is one frame read from the channel together with any rpc-status
// error that accompanied it.
type message struct {
	messageHeader
	p   []byte
	err error
}

// run is the client's event loop. It writes queued requests, routes incoming
// messages to the waiting call by stream id, and on shutdown or Close fails
// every outstanding call. On exit it closes done and then the connection.
func (c *Client) run() {
	var (
		streamID    uint32 = 1
		waiters            = make(map[uint32]*callRequest)
		calls              = c.calls
		incoming           = make(chan *message)
		shutdown           = make(chan struct{})
		shutdownErr error
	)

	// Receive on a separate goroutine so a blocking read never stalls the
	// loop below. shutdownErr is written here and may only be read by the
	// loop after shutdown has been observed closed.
	go func() {
		defer close(shutdown)

		for {
			mh, p, err := c.channel.recv(context.TODO())
			if err != nil {
				if _, ok := status.FromError(err); !ok {
					// Errors that are not an rpc status are terminal for
					// the connection; rpc status errors are forwarded to
					// the waiting call instead.
					shutdownErr = err
					return
				}
			}

			if err == nil {
				// Only trim the buffer on success. What recv returns in p
				// alongside a status error is not visible here, and slicing
				// a nil or short buffer would panic.
				p = p[:mh.Length]
			}

			select {
			case incoming <- &message{
				messageHeader: mh,
				p:             p,
				err:           err,
			}:
			case <-c.done:
				return
			}
		}
	}()

	// Deferred calls run in reverse order: done is closed first, then conn.
	defer c.conn.Close()
	defer close(c.done)

	for {
		select {
		case call := <-calls:
			if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
				call.errs <- err
				continue
			}

			waiters[streamID] = call
			streamID += 2 // enforce odd client initiated request ids
		case msg := <-incoming:
			call, ok := waiters[msg.StreamID]
			if !ok {
				log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
				continue
			}

			call.errs <- c.recv(call.resp, msg)
			delete(waiters, msg.StreamID)
		case <-shutdown:
			shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
			c.err = shutdownErr
			for _, waiter := range waiters {
				waiter.errs <- shutdownErr
			}
			c.Close()
			return
		case <-c.closed:
			// The receive goroutine may still be running, so shutdownErr
			// must not be read here. Record an explicit error so waiters and
			// later callers see a failure instead of a nil error.
			c.err = errors.New("ttrpc: client closed")
			for _, waiter := range waiters {
				waiter.errs <- c.err
			}
			return
		}
	}
}

// send marshals msg with the client codec and writes it to the channel as a
// frame of type mtype on streamID.
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
	p, err := c.codec.Marshal(msg)
	if err != nil {
		return err
	}

	return c.channel.send(ctx, streamID, mtype, p)
}

// recv decodes msg into resp. It returns msg.err if set, an error if the
// frame is not a response, and otherwise the result of unmarshalling. The
// message buffer is handed back to the channel only on the decode path.
func (c *Client) recv(resp *Response, msg *message) error {
	if msg.err != nil {
		return msg.err
	}

	if msg.Type != messageTypeResponse {
		return errors.New("unkown message type received")
	}

	defer c.channel.putmbuf(msg.p)
	return proto.Unmarshal(msg.p, resp)
}