2017-11-29 18:52:16 -05:00
|
|
|
package ttrpc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
"github.com/containerd/containerd/log"
|
2017-11-29 18:52:16 -05:00
|
|
|
"github.com/gogo/protobuf/proto"
|
|
|
|
"github.com/pkg/errors"
|
|
|
|
"google.golang.org/grpc/status"
|
|
|
|
)
|
|
|
|
|
|
|
|
type Client struct {
|
2017-11-30 11:11:33 -05:00
|
|
|
codec codec
|
|
|
|
conn net.Conn
|
|
|
|
channel *channel
|
|
|
|
calls chan *callRequest
|
2017-11-29 18:52:16 -05:00
|
|
|
|
|
|
|
closed chan struct{}
|
|
|
|
closeOnce sync.Once
|
|
|
|
done chan struct{}
|
|
|
|
err error
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewClient(conn net.Conn) *Client {
|
|
|
|
c := &Client{
|
2017-11-30 11:11:33 -05:00
|
|
|
codec: codec{},
|
|
|
|
conn: conn,
|
|
|
|
channel: newChannel(conn, conn),
|
|
|
|
calls: make(chan *callRequest),
|
|
|
|
closed: make(chan struct{}),
|
|
|
|
done: make(chan struct{}),
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
go c.run()
|
|
|
|
return c
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
type callRequest struct {
|
|
|
|
ctx context.Context
|
|
|
|
req *Request
|
|
|
|
resp *Response // response will be written back here
|
|
|
|
errs chan error // error written here on completion
|
|
|
|
}
|
|
|
|
|
2017-11-29 18:52:16 -05:00
|
|
|
func (c *Client) Call(ctx context.Context, service, method string, req, resp interface{}) error {
|
|
|
|
payload, err := c.codec.Marshal(req)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
var (
|
|
|
|
creq = &Request{
|
|
|
|
Service: service,
|
|
|
|
Method: method,
|
|
|
|
Payload: payload,
|
|
|
|
}
|
2017-11-29 18:52:16 -05:00
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
cresp = &Response{}
|
|
|
|
)
|
2017-11-29 18:52:16 -05:00
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
if err := c.dispatch(ctx, creq, cresp); err != nil {
|
2017-11-29 18:52:16 -05:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
if err := c.codec.Unmarshal(cresp.Payload, resp); err != nil {
|
2017-11-29 18:52:16 -05:00
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
if cresp.Status == nil {
|
2017-11-29 18:52:16 -05:00
|
|
|
return errors.New("no status provided on response")
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
return status.ErrorProto(cresp.Status)
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error {
|
2017-11-29 18:52:16 -05:00
|
|
|
errs := make(chan error, 1)
|
2017-11-30 11:11:33 -05:00
|
|
|
call := &callRequest{
|
|
|
|
req: req,
|
|
|
|
resp: resp,
|
|
|
|
errs: errs,
|
|
|
|
}
|
|
|
|
|
2017-11-29 18:52:16 -05:00
|
|
|
select {
|
2017-11-30 11:11:33 -05:00
|
|
|
case c.calls <- call:
|
2017-11-29 18:52:16 -05:00
|
|
|
case <-c.done:
|
|
|
|
return c.err
|
|
|
|
}
|
|
|
|
|
|
|
|
select {
|
|
|
|
case err := <-errs:
|
|
|
|
return err
|
|
|
|
case <-c.done:
|
|
|
|
return c.err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
func (c *Client) Close() error {
|
|
|
|
c.closeOnce.Do(func() {
|
|
|
|
close(c.closed)
|
|
|
|
})
|
2017-11-29 18:52:16 -05:00
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
return nil
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
type message struct {
|
|
|
|
messageHeader
|
2017-11-29 18:52:16 -05:00
|
|
|
p []byte
|
|
|
|
err error
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Client) run() {
|
|
|
|
var (
|
2017-11-30 11:11:33 -05:00
|
|
|
streamID uint32 = 1
|
|
|
|
waiters = make(map[uint32]*callRequest)
|
|
|
|
calls = c.calls
|
|
|
|
incoming = make(chan *message)
|
|
|
|
shutdown = make(chan struct{})
|
|
|
|
shutdownErr error
|
2017-11-29 18:52:16 -05:00
|
|
|
)
|
|
|
|
|
|
|
|
go func() {
|
2017-11-30 11:11:33 -05:00
|
|
|
defer close(shutdown)
|
|
|
|
|
2017-11-29 18:52:16 -05:00
|
|
|
// start one more goroutine to recv messages without blocking.
|
|
|
|
for {
|
2017-11-30 11:11:33 -05:00
|
|
|
mh, p, err := c.channel.recv(context.TODO())
|
|
|
|
if err != nil {
|
|
|
|
_, ok := status.FromError(err)
|
|
|
|
if !ok {
|
|
|
|
// treat all errors that are not an rpc status as terminal.
|
|
|
|
// all others poison the connection.
|
|
|
|
shutdownErr = err
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
2017-11-29 18:52:16 -05:00
|
|
|
select {
|
2017-11-30 11:11:33 -05:00
|
|
|
case incoming <- &message{
|
|
|
|
messageHeader: mh,
|
|
|
|
p: p[:mh.Length],
|
|
|
|
err: err,
|
2017-11-29 18:52:16 -05:00
|
|
|
}:
|
|
|
|
case <-c.done:
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
defer c.conn.Close()
|
|
|
|
defer close(c.done)
|
|
|
|
|
2017-11-29 18:52:16 -05:00
|
|
|
for {
|
|
|
|
select {
|
2017-11-30 11:11:33 -05:00
|
|
|
case call := <-calls:
|
|
|
|
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
|
|
|
|
call.errs <- err
|
|
|
|
continue
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
2017-11-30 11:11:33 -05:00
|
|
|
|
|
|
|
waiters[streamID] = call
|
|
|
|
streamID += 2 // enforce odd client initiated request ids
|
|
|
|
case msg := <-incoming:
|
|
|
|
call, ok := waiters[msg.StreamID]
|
|
|
|
if !ok {
|
|
|
|
log.L.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
|
|
|
|
continue
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
|
|
|
|
2017-11-30 11:11:33 -05:00
|
|
|
call.errs <- c.recv(call.resp, msg)
|
|
|
|
delete(waiters, msg.StreamID)
|
|
|
|
case <-shutdown:
|
|
|
|
shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")
|
|
|
|
c.err = shutdownErr
|
|
|
|
for _, waiter := range waiters {
|
|
|
|
waiter.errs <- shutdownErr
|
2017-11-29 18:52:16 -05:00
|
|
|
}
|
2017-11-30 11:11:33 -05:00
|
|
|
c.Close()
|
|
|
|
return
|
2017-11-29 18:52:16 -05:00
|
|
|
case <-c.closed:
|
2017-11-30 11:11:33 -05:00
|
|
|
// broadcast the shutdown error to the remaining waiters.
|
|
|
|
for _, waiter := range waiters {
|
|
|
|
waiter.errs <- shutdownErr
|
|
|
|
}
|
2017-11-29 18:52:16 -05:00
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
2017-11-30 11:11:33 -05:00
|
|
|
|
|
|
|
func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
|
|
|
|
p, err := c.codec.Marshal(msg)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
return c.channel.send(ctx, streamID, mtype, p)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (c *Client) recv(resp *Response, msg *message) error {
|
|
|
|
if msg.err != nil {
|
|
|
|
return msg.err
|
|
|
|
}
|
|
|
|
|
|
|
|
if msg.Type != messageTypeResponse {
|
|
|
|
return errors.New("unkown message type received")
|
|
|
|
}
|
|
|
|
|
|
|
|
defer c.channel.putmbuf(msg.p)
|
|
|
|
return proto.Unmarshal(msg.p, resp)
|
|
|
|
}
|