1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00

bump google.golang.org/grpc v1.23.0

full diff: https://github.com/grpc/grpc-go/compare/v1.20.1...v1.23.0

This update contains security fixes:

- transport: block reading frames when too many transport control frames are queued (grpc/grpc-go#2970)
  - Addresses CVE-2019-9512 (Ping Flood), CVE-2019-9514 (Reset Flood), and CVE-2019-9515 (Settings Flood).

Other changes can be found in the release notes:
https://github.com/grpc/grpc-go/releases/tag/v1.23.0

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2019-08-26 16:56:14 +02:00
parent 7ce0e26c16
commit f1cd79976a
No known key found for this signature in database
GPG key ID: 76698F39D527CE8C
36 changed files with 1217 additions and 535 deletions

View file

@ -73,7 +73,7 @@ github.com/opencontainers/go-digest 279bed98673dd5bef374d3b6e4b0
# get go-zfs packages # get go-zfs packages
github.com/mistifyio/go-zfs f784269be439d704d3dfa1906f45dd848fed2beb github.com/mistifyio/go-zfs f784269be439d704d3dfa1906f45dd848fed2beb
google.golang.org/grpc 25c4f928eaa6d96443009bd842389fb4fa48664e # v1.20.1 google.golang.org/grpc 6eaf6f47437a6b4e2153a190160ef39a92c7eceb # v1.23.0
# The version of runc should match the version that is used by the containerd # The version of runc should match the version that is used by the containerd
# version that is used. If you need to update runc, open a pull request in # version that is used. If you need to update runc, open a pull request in

View file

@ -1,42 +1,96 @@
# gRPC-Go # gRPC-Go
[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc) [![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go) [![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go)
[![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc)
[![GoReportCard](https://goreportcard.com/badge/grpc/grpc-go)](https://goreportcard.com/report/github.com/grpc/grpc-go)
The Go implementation of [gRPC](https://grpc.io/): A high performance, open source, general RPC framework that puts mobile and HTTP/2 first. For more information see the [gRPC Quick Start: Go](https://grpc.io/docs/quickstart/go.html) guide. The Go implementation of [gRPC](https://grpc.io/): A high performance, open
source, general RPC framework that puts mobile and HTTP/2 first. For more
information see the [gRPC Quick Start:
Go](https://grpc.io/docs/quickstart/go.html) guide.
Installation Installation
------------ ------------
To install this package, you need to install Go and setup your Go workspace on your computer. The simplest way to install the library is to run: To install this package, you need to install Go and setup your Go workspace on
your computer. The simplest way to install the library is to run:
``` ```
$ go get -u google.golang.org/grpc $ go get -u google.golang.org/grpc
``` ```
With Go module support (Go 1.11+), simply `import "google.golang.org/grpc"` in
your source code and `go [build|run|test]` will automatically download the
necessary dependencies ([Go modules
ref](https://github.com/golang/go/wiki/Modules)).
If you are trying to access grpc-go from within China, please see the
[FAQ](#FAQ) below.
Prerequisites Prerequisites
------------- -------------
gRPC-Go requires Go 1.9 or later. gRPC-Go requires Go 1.9 or later.
Constraints
-----------
The grpc package should only depend on standard Go packages and a small number of exceptions. If your contribution introduces new dependencies which are NOT in the [list](https://godoc.org/google.golang.org/grpc?imports), you need a discussion with gRPC-Go authors and consultants.
Documentation Documentation
------------- -------------
See [API documentation](https://godoc.org/google.golang.org/grpc) for package and API descriptions and find examples in the [examples directory](examples/). - See [godoc](https://godoc.org/google.golang.org/grpc) for package and API
descriptions.
- Documentation on specific topics can be found in the [Documentation
directory](Documentation/).
- Examples can be found in the [examples directory](examples/).
Performance Performance
----------- -----------
See the current benchmarks for some of the languages supported in [this dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696). Performance benchmark data for grpc-go and other languages is maintained in
[this
dashboard](https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696).
Status Status
------ ------
General Availability [Google Cloud Platform Launch Stages](https://cloud.google.com/terms/launch-stages). General Availability [Google Cloud Platform Launch
Stages](https://cloud.google.com/terms/launch-stages).
FAQ FAQ
--- ---
#### I/O Timeout Errors
The `golang.org` domain may be blocked from some countries. `go get` usually
produces an error like the following when this happens:
```
$ go get -u google.golang.org/grpc
package google.golang.org/grpc: unrecognized import path "google.golang.org/grpc" (https fetch: Get https://google.golang.org/grpc?go-get=1: dial tcp 216.239.37.1:443: i/o timeout)
```
To build Go code, there are several options:
- Set up a VPN and access google.golang.org through that.
- Without Go module support: `git clone` the repo manually:
```
git clone https://github.com/grpc/grpc-go.git $GOPATH/src/google.golang.org/grpc
```
You will need to do the same for all of grpc's dependencies in `golang.org`,
e.g. `golang.org/x/net`.
- With Go module support: it is possible to use the `replace` feature of `go
mod` to create aliases for golang.org packages. In your project's directory:
```
go mod edit -replace=google.golang.org/grpc=github.com/grpc/grpc-go@latest
go mod tidy
go mod vendor
go build -mod=vendor
```
Again, this will need to be done for all transitive dependencies hosted on
golang.org as well. Please refer to [this
issue](https://github.com/golang/go/issues/28652) in the golang repo regarding
this concern.
#### Compiling error, undefined: grpc.SupportPackageIsVersion #### Compiling error, undefined: grpc.SupportPackageIsVersion
Please update proto package, gRPC package and rebuild the proto files: Please update proto package, gRPC package and rebuild the proto files:

View file

@ -43,7 +43,7 @@ type Address struct {
// BalancerConfig specifies the configurations for Balancer. // BalancerConfig specifies the configurations for Balancer.
// //
// Deprecated: please use package balancer. // Deprecated: please use package balancer. May be removed in a future 1.x release.
type BalancerConfig struct { type BalancerConfig struct {
// DialCreds is the transport credential the Balancer implementation can // DialCreds is the transport credential the Balancer implementation can
// use to dial to a remote load balancer server. The Balancer implementations // use to dial to a remote load balancer server. The Balancer implementations
@ -57,7 +57,7 @@ type BalancerConfig struct {
// BalancerGetOptions configures a Get call. // BalancerGetOptions configures a Get call.
// //
// Deprecated: please use package balancer. // Deprecated: please use package balancer. May be removed in a future 1.x release.
type BalancerGetOptions struct { type BalancerGetOptions struct {
// BlockingWait specifies whether Get should block when there is no // BlockingWait specifies whether Get should block when there is no
// connected address. // connected address.
@ -66,7 +66,7 @@ type BalancerGetOptions struct {
// Balancer chooses network addresses for RPCs. // Balancer chooses network addresses for RPCs.
// //
// Deprecated: please use package balancer. // Deprecated: please use package balancer. May be removed in a future 1.x release.
type Balancer interface { type Balancer interface {
// Start does the initialization work to bootstrap a Balancer. For example, // Start does the initialization work to bootstrap a Balancer. For example,
// this function may start the name resolution and watch the updates. It will // this function may start the name resolution and watch the updates. It will
@ -120,7 +120,7 @@ type Balancer interface {
// RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch // RoundRobin returns a Balancer that selects addresses round-robin. It uses r to watch
// the name resolution updates and updates the addresses available correspondingly. // the name resolution updates and updates the addresses available correspondingly.
// //
// Deprecated: please use package balancer/roundrobin. // Deprecated: please use package balancer/roundrobin. May be removed in a future 1.x release.
func RoundRobin(r naming.Resolver) Balancer { func RoundRobin(r naming.Resolver) Balancer {
return &roundRobin{r: r} return &roundRobin{r: r}
} }

View file

@ -22,6 +22,7 @@ package balancer
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"net" "net"
"strings" "strings"
@ -31,6 +32,7 @@ import (
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"
) )
var ( var (
@ -39,7 +41,10 @@ var (
) )
// Register registers the balancer builder to the balancer map. b.Name // Register registers the balancer builder to the balancer map. b.Name
// (lowercased) will be used as the name registered with this builder. // (lowercased) will be used as the name registered with this builder. If the
// Builder implements ConfigParser, ParseConfig will be called when new service
// configs are received by the resolver, and the result will be provided to the
// Balancer in UpdateClientConnState.
// //
// NOTE: this function must only be called during initialization time (i.e. in // NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. If multiple Balancers are // an init() function), and is not thread-safe. If multiple Balancers are
@ -138,6 +143,8 @@ type ClientConn interface {
ResolveNow(resolver.ResolveNowOption) ResolveNow(resolver.ResolveNowOption)
// Target returns the dial target for this ClientConn. // Target returns the dial target for this ClientConn.
//
// Deprecated: Use the Target field in the BuildOptions instead.
Target() string Target() string
} }
@ -155,6 +162,10 @@ type BuildOptions struct {
Dialer func(context.Context, string) (net.Conn, error) Dialer func(context.Context, string) (net.Conn, error)
// ChannelzParentID is the entity parent's channelz unique identification number. // ChannelzParentID is the entity parent's channelz unique identification number.
ChannelzParentID int64 ChannelzParentID int64
// Target contains the parsed address info of the dial target. It is the same resolver.Target as
// passed to the resolver.
// See the documentation for the resolver.Target type for details about what it contains.
Target resolver.Target
} }
// Builder creates a balancer. // Builder creates a balancer.
@ -166,6 +177,14 @@ type Builder interface {
Name() string Name() string
} }
// ConfigParser parses load balancer configs.
type ConfigParser interface {
// ParseConfig parses the JSON load balancer config provided into an
// internal form or returns an error if the config is invalid. For future
// compatibility reasons, unknown fields in the config should be ignored.
ParseConfig(LoadBalancingConfigJSON json.RawMessage) (serviceconfig.LoadBalancingConfig, error)
}
// PickOptions contains addition information for the Pick operation. // PickOptions contains addition information for the Pick operation.
type PickOptions struct { type PickOptions struct {
// FullMethodName is the method name that NewClientStream() is called // FullMethodName is the method name that NewClientStream() is called
@ -264,7 +283,7 @@ type Balancer interface {
// non-nil error to gRPC. // non-nil error to gRPC.
// //
// Deprecated: if V2Balancer is implemented by the Balancer, // Deprecated: if V2Balancer is implemented by the Balancer,
// UpdateResolverState will be called instead. // UpdateClientConnState will be called instead.
HandleResolvedAddrs([]resolver.Address, error) HandleResolvedAddrs([]resolver.Address, error)
// Close closes the balancer. The balancer is not required to call // Close closes the balancer. The balancer is not required to call
// ClientConn.RemoveSubConn for its existing SubConns. // ClientConn.RemoveSubConn for its existing SubConns.
@ -277,14 +296,23 @@ type SubConnState struct {
// TODO: add last connection error // TODO: add last connection error
} }
// ClientConnState describes the state of a ClientConn relevant to the
// balancer.
type ClientConnState struct {
ResolverState resolver.State
// The parsed load balancing configuration returned by the builder's
// ParseConfig method, if implemented.
BalancerConfig serviceconfig.LoadBalancingConfig
}
// V2Balancer is defined for documentation purposes. If a Balancer also // V2Balancer is defined for documentation purposes. If a Balancer also
// implements V2Balancer, its UpdateResolverState method will be called instead // implements V2Balancer, its UpdateClientConnState method will be called
// of HandleResolvedAddrs and its UpdateSubConnState will be called instead of // instead of HandleResolvedAddrs and its UpdateSubConnState will be called
// HandleSubConnStateChange. // instead of HandleSubConnStateChange.
type V2Balancer interface { type V2Balancer interface {
// UpdateResolverState is called by gRPC when the state of the resolver // UpdateClientConnState is called by gRPC when the state of the ClientConn
// changes. // changes.
UpdateResolverState(resolver.State) UpdateClientConnState(ClientConnState)
// UpdateSubConnState is called by gRPC when the state of a SubConn // UpdateSubConnState is called by gRPC when the state of a SubConn
// changes. // changes.
UpdateSubConnState(SubConn, SubConnState) UpdateSubConnState(SubConn, SubConnState)

View file

@ -70,13 +70,15 @@ func (b *baseBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error)
panic("not implemented") panic("not implemented")
} }
func (b *baseBalancer) UpdateResolverState(s resolver.State) { func (b *baseBalancer) UpdateClientConnState(s balancer.ClientConnState) {
// TODO: handle s.Err (log if not nil) once implemented. // TODO: handle s.ResolverState.Err (log if not nil) once implemented.
// TODO: handle s.ServiceConfig? // TODO: handle s.ResolverState.ServiceConfig?
grpclog.Infoln("base.baseBalancer: got new resolver state: ", s) if grpclog.V(2) {
grpclog.Infoln("base.baseBalancer: got new ClientConn state: ", s)
}
// addrsSet is the set converted from addrs, it's used for quick lookup of an address. // addrsSet is the set converted from addrs, it's used for quick lookup of an address.
addrsSet := make(map[resolver.Address]struct{}) addrsSet := make(map[resolver.Address]struct{})
for _, a := range s.Addresses { for _, a := range s.ResolverState.Addresses {
addrsSet[a] = struct{}{} addrsSet[a] = struct{}{}
if _, ok := b.subConns[a]; !ok { if _, ok := b.subConns[a]; !ok {
// a is a new address (not existing in b.subConns). // a is a new address (not existing in b.subConns).
@ -127,10 +129,14 @@ func (b *baseBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectiv
func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { func (b *baseBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
s := state.ConnectivityState s := state.ConnectivityState
grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s) if grpclog.V(2) {
grpclog.Infof("base.baseBalancer: handle SubConn state change: %p, %v", sc, s)
}
oldS, ok := b.scStates[sc] oldS, ok := b.scStates[sc]
if !ok { if !ok {
grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s) if grpclog.V(2) {
grpclog.Infof("base.baseBalancer: got state changes for an unknown SubConn: %p, %v", sc, s)
}
return return
} }
b.scStates[sc] = s b.scStates[sc] = s

View file

@ -88,7 +88,7 @@ type ccBalancerWrapper struct {
cc *ClientConn cc *ClientConn
balancer balancer.Balancer balancer balancer.Balancer
stateChangeQueue *scStateUpdateBuffer stateChangeQueue *scStateUpdateBuffer
resolverUpdateCh chan *resolver.State ccUpdateCh chan *balancer.ClientConnState
done chan struct{} done chan struct{}
mu sync.Mutex mu sync.Mutex
@ -99,7 +99,7 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
ccb := &ccBalancerWrapper{ ccb := &ccBalancerWrapper{
cc: cc, cc: cc,
stateChangeQueue: newSCStateUpdateBuffer(), stateChangeQueue: newSCStateUpdateBuffer(),
resolverUpdateCh: make(chan *resolver.State, 1), ccUpdateCh: make(chan *balancer.ClientConnState, 1),
done: make(chan struct{}), done: make(chan struct{}),
subConns: make(map[*acBalancerWrapper]struct{}), subConns: make(map[*acBalancerWrapper]struct{}),
} }
@ -126,7 +126,7 @@ func (ccb *ccBalancerWrapper) watcher() {
} else { } else {
ccb.balancer.HandleSubConnStateChange(t.sc, t.state) ccb.balancer.HandleSubConnStateChange(t.sc, t.state)
} }
case s := <-ccb.resolverUpdateCh: case s := <-ccb.ccUpdateCh:
select { select {
case <-ccb.done: case <-ccb.done:
ccb.balancer.Close() ccb.balancer.Close()
@ -134,9 +134,9 @@ func (ccb *ccBalancerWrapper) watcher() {
default: default:
} }
if ub, ok := ccb.balancer.(balancer.V2Balancer); ok { if ub, ok := ccb.balancer.(balancer.V2Balancer); ok {
ub.UpdateResolverState(*s) ub.UpdateClientConnState(*s)
} else { } else {
ccb.balancer.HandleResolvedAddrs(s.Addresses, nil) ccb.balancer.HandleResolvedAddrs(s.ResolverState.Addresses, nil)
} }
case <-ccb.done: case <-ccb.done:
} }
@ -151,9 +151,11 @@ func (ccb *ccBalancerWrapper) watcher() {
for acbw := range scs { for acbw := range scs {
ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain) ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain)
} }
ccb.UpdateBalancerState(connectivity.Connecting, nil)
return return
default: default:
} }
ccb.cc.firstResolveEvent.Fire()
} }
} }
@ -178,9 +180,10 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co
}) })
} }
func (ccb *ccBalancerWrapper) updateResolverState(s resolver.State) { func (ccb *ccBalancerWrapper) updateClientConnState(ccs *balancer.ClientConnState) {
if ccb.cc.curBalancerName != grpclbName { if ccb.cc.curBalancerName != grpclbName {
// Filter any grpclb addresses since we don't have the grpclb balancer. // Filter any grpclb addresses since we don't have the grpclb balancer.
s := &ccs.ResolverState
for i := 0; i < len(s.Addresses); { for i := 0; i < len(s.Addresses); {
if s.Addresses[i].Type == resolver.GRPCLB { if s.Addresses[i].Type == resolver.GRPCLB {
copy(s.Addresses[i:], s.Addresses[i+1:]) copy(s.Addresses[i:], s.Addresses[i+1:])
@ -191,10 +194,10 @@ func (ccb *ccBalancerWrapper) updateResolverState(s resolver.State) {
} }
} }
select { select {
case <-ccb.resolverUpdateCh: case <-ccb.ccUpdateCh:
default: default:
} }
ccb.resolverUpdateCh <- &s ccb.ccUpdateCh <- ccs
} }
func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {

View file

@ -20,7 +20,6 @@ package grpc
import ( import (
"context" "context"
"strings"
"sync" "sync"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
@ -34,13 +33,7 @@ type balancerWrapperBuilder struct {
} }
func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
targetAddr := cc.Target() bwb.b.Start(opts.Target.Endpoint, BalancerConfig{
targetSplitted := strings.Split(targetAddr, ":///")
if len(targetSplitted) >= 2 {
targetAddr = targetSplitted[1]
}
bwb.b.Start(targetAddr, BalancerConfig{
DialCreds: opts.DialCreds, DialCreds: opts.DialCreds,
Dialer: opts.Dialer, Dialer: opts.Dialer,
}) })
@ -49,7 +42,7 @@ func (bwb *balancerWrapperBuilder) Build(cc balancer.ClientConn, opts balancer.B
balancer: bwb.b, balancer: bwb.b,
pickfirst: pickfirst, pickfirst: pickfirst,
cc: cc, cc: cc,
targetAddr: targetAddr, targetAddr: opts.Target.Endpoint,
startCh: make(chan struct{}), startCh: make(chan struct{}),
conns: make(map[resolver.Address]balancer.SubConn), conns: make(map[resolver.Address]balancer.SubConn),
connSt: make(map[balancer.SubConn]*scState), connSt: make(map[balancer.SubConn]*scState),
@ -120,7 +113,7 @@ func (bw *balancerWrapper) lbWatcher() {
} }
for addrs := range notifyCh { for addrs := range notifyCh {
grpclog.Infof("balancerWrapper: got update addr from Notify: %v\n", addrs) grpclog.Infof("balancerWrapper: got update addr from Notify: %v", addrs)
if bw.pickfirst { if bw.pickfirst {
var ( var (
oldA resolver.Address oldA resolver.Address

View file

@ -38,13 +38,13 @@ import (
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
_ "google.golang.org/grpc/resolver/dns" // To register dns resolver. _ "google.golang.org/grpc/resolver/dns" // To register dns resolver.
_ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver. _ "google.golang.org/grpc/resolver/passthrough" // To register passthrough resolver.
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -137,6 +137,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
opt.apply(&cc.dopts) opt.apply(&cc.dopts)
} }
chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc)
defer func() { defer func() {
if err != nil { if err != nil {
cc.Close() cc.Close()
@ -290,6 +293,7 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
CredsBundle: cc.dopts.copts.CredsBundle, CredsBundle: cc.dopts.copts.CredsBundle,
Dialer: cc.dopts.copts.Dialer, Dialer: cc.dopts.copts.Dialer,
ChannelzParentID: cc.channelzID, ChannelzParentID: cc.channelzID,
Target: cc.parsedTarget,
} }
// Build the resolver. // Build the resolver.
@ -327,6 +331,68 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
return cc, nil return cc, nil
} }
// chainUnaryClientInterceptors chains all unary client interceptors into one.
func chainUnaryClientInterceptors(cc *ClientConn) {
interceptors := cc.dopts.chainUnaryInts
// Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
// be executed before any other chained interceptors.
if cc.dopts.unaryInt != nil {
interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
}
var chainedInt UnaryClientInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
}
}
cc.dopts.unaryInt = chainedInt
}
// getChainUnaryInvoker recursively generate the chained unary invoker.
func getChainUnaryInvoker(interceptors []UnaryClientInterceptor, curr int, finalInvoker UnaryInvoker) UnaryInvoker {
if curr == len(interceptors)-1 {
return finalInvoker
}
return func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
return interceptors[curr+1](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, curr+1, finalInvoker), opts...)
}
}
// chainStreamClientInterceptors chains all stream client interceptors into one.
func chainStreamClientInterceptors(cc *ClientConn) {
interceptors := cc.dopts.chainStreamInts
// Prepend dopts.streamInt to the chaining interceptors if it exists, since streamInt will
// be executed before any other chained interceptors.
if cc.dopts.streamInt != nil {
interceptors = append([]StreamClientInterceptor{cc.dopts.streamInt}, interceptors...)
}
var chainedInt StreamClientInterceptor
if len(interceptors) == 0 {
chainedInt = nil
} else if len(interceptors) == 1 {
chainedInt = interceptors[0]
} else {
chainedInt = func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) {
return interceptors[0](ctx, desc, cc, method, getChainStreamer(interceptors, 0, streamer), opts...)
}
}
cc.dopts.streamInt = chainedInt
}
// getChainStreamer recursively generate the chained client stream constructor.
func getChainStreamer(interceptors []StreamClientInterceptor, curr int, finalStreamer Streamer) Streamer {
if curr == len(interceptors)-1 {
return finalStreamer
}
return func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
return interceptors[curr+1](ctx, desc, cc, method, getChainStreamer(interceptors, curr+1, finalStreamer), opts...)
}
}
// connectivityStateManager keeps the connectivity.State of ClientConn. // connectivityStateManager keeps the connectivity.State of ClientConn.
// This struct will eventually be exported so the balancers can access it. // This struct will eventually be exported so the balancers can access it.
type connectivityStateManager struct { type connectivityStateManager struct {
@ -466,24 +532,6 @@ func (cc *ClientConn) waitForResolvedAddrs(ctx context.Context) error {
} }
} }
// gRPC should resort to default service config when:
// * resolver service config is disabled
// * or, resolver does not return a service config or returns an invalid one.
func (cc *ClientConn) fallbackToDefaultServiceConfig(sc string) bool {
if cc.dopts.disableServiceConfig {
return true
}
// The logic below is temporary, will be removed once we change the resolver.State ServiceConfig field type.
// Right now, we assume that empty service config string means resolver does not return a config.
if sc == "" {
return true
}
// TODO: the logic below is temporary. Once we finish the logic to validate service config
// in resolver, we will replace the logic below.
_, err := parseServiceConfig(sc)
return err != nil
}
func (cc *ClientConn) updateResolverState(s resolver.State) error { func (cc *ClientConn) updateResolverState(s resolver.State) error {
cc.mu.Lock() cc.mu.Lock()
defer cc.mu.Unlock() defer cc.mu.Unlock()
@ -494,54 +542,47 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error {
return nil return nil
} }
if cc.fallbackToDefaultServiceConfig(s.ServiceConfig) { if cc.dopts.disableServiceConfig || s.ServiceConfig == nil {
if cc.dopts.defaultServiceConfig != nil && cc.sc == nil { if cc.dopts.defaultServiceConfig != nil && cc.sc == nil {
cc.applyServiceConfig(cc.dopts.defaultServiceConfig) cc.applyServiceConfig(cc.dopts.defaultServiceConfig)
} }
} else { } else if sc, ok := s.ServiceConfig.(*ServiceConfig); ok {
// TODO: the parsing logic below will be moved inside resolver. cc.applyServiceConfig(sc)
sc, err := parseServiceConfig(s.ServiceConfig)
if err != nil {
return err
}
if cc.sc == nil || cc.sc.rawJSONString != s.ServiceConfig {
cc.applyServiceConfig(sc)
}
}
// update the service config that will be sent to balancer.
if cc.sc != nil {
s.ServiceConfig = cc.sc.rawJSONString
} }
var balCfg serviceconfig.LoadBalancingConfig
if cc.dopts.balancerBuilder == nil { if cc.dopts.balancerBuilder == nil {
// Only look at balancer types and switch balancer if balancer dial // Only look at balancer types and switch balancer if balancer dial
// option is not set. // option is not set.
var isGRPCLB bool
for _, a := range s.Addresses {
if a.Type == resolver.GRPCLB {
isGRPCLB = true
break
}
}
var newBalancerName string var newBalancerName string
// TODO: use new loadBalancerConfig field with appropriate priority. if cc.sc != nil && cc.sc.lbConfig != nil {
if isGRPCLB { newBalancerName = cc.sc.lbConfig.name
newBalancerName = grpclbName balCfg = cc.sc.lbConfig.cfg
} else if cc.sc != nil && cc.sc.LB != nil {
newBalancerName = *cc.sc.LB
} else { } else {
newBalancerName = PickFirstBalancerName var isGRPCLB bool
for _, a := range s.Addresses {
if a.Type == resolver.GRPCLB {
isGRPCLB = true
break
}
}
if isGRPCLB {
newBalancerName = grpclbName
} else if cc.sc != nil && cc.sc.LB != nil {
newBalancerName = *cc.sc.LB
} else {
newBalancerName = PickFirstBalancerName
}
} }
cc.switchBalancer(newBalancerName) cc.switchBalancer(newBalancerName)
} else if cc.balancerWrapper == nil { } else if cc.balancerWrapper == nil {
// Balancer dial option was set, and this is the first time handling // Balancer dial option was set, and this is the first time handling
// resolved addresses. Build a balancer with dopts.balancerBuilder. // resolved addresses. Build a balancer with dopts.balancerBuilder.
cc.curBalancerName = cc.dopts.balancerBuilder.Name()
cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts) cc.balancerWrapper = newCCBalancerWrapper(cc, cc.dopts.balancerBuilder, cc.balancerBuildOpts)
} }
cc.balancerWrapper.updateResolverState(s) cc.balancerWrapper.updateClientConnState(&balancer.ClientConnState{ResolverState: s, BalancerConfig: balCfg})
cc.firstResolveEvent.Fire()
return nil return nil
} }
@ -554,7 +595,7 @@ func (cc *ClientConn) updateResolverState(s resolver.State) error {
// //
// Caller must hold cc.mu. // Caller must hold cc.mu.
func (cc *ClientConn) switchBalancer(name string) { func (cc *ClientConn) switchBalancer(name string) {
if strings.ToLower(cc.curBalancerName) == strings.ToLower(name) { if strings.EqualFold(cc.curBalancerName, name) {
return return
} }
@ -693,6 +734,8 @@ func (ac *addrConn) connect() error {
ac.mu.Unlock() ac.mu.Unlock()
return nil return nil
} }
// Update connectivity state within the lock to prevent subsequent or
// concurrent calls from resetting the transport more than once.
ac.updateConnectivityState(connectivity.Connecting) ac.updateConnectivityState(connectivity.Connecting)
ac.mu.Unlock() ac.mu.Unlock()
@ -703,7 +746,16 @@ func (ac *addrConn) connect() error {
// tryUpdateAddrs tries to update ac.addrs with the new addresses list. // tryUpdateAddrs tries to update ac.addrs with the new addresses list.
// //
// It checks whether current connected address of ac is in the new addrs list. // If ac is Connecting, it returns false. The caller should tear down the ac and
// create a new one. Note that the backoff will be reset when this happens.
//
// If ac is TransientFailure, it updates ac.addrs and returns true. The updated
// addresses will be picked up by retry in the next iteration after backoff.
//
// If ac is Shutdown or Idle, it updates ac.addrs and returns true.
//
// If ac is Ready, it checks whether current connected address of ac is in the
// new addrs list.
// - If true, it updates ac.addrs and returns true. The ac will keep using // - If true, it updates ac.addrs and returns true. The ac will keep using
// the existing connection. // the existing connection.
// - If false, it does nothing and returns false. // - If false, it does nothing and returns false.
@ -711,17 +763,18 @@ func (ac *addrConn) tryUpdateAddrs(addrs []resolver.Address) bool {
ac.mu.Lock() ac.mu.Lock()
defer ac.mu.Unlock() defer ac.mu.Unlock()
grpclog.Infof("addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs) grpclog.Infof("addrConn: tryUpdateAddrs curAddr: %v, addrs: %v", ac.curAddr, addrs)
if ac.state == connectivity.Shutdown { if ac.state == connectivity.Shutdown ||
ac.state == connectivity.TransientFailure ||
ac.state == connectivity.Idle {
ac.addrs = addrs ac.addrs = addrs
return true return true
} }
// Unless we're busy reconnecting already, let's reconnect from the top of if ac.state == connectivity.Connecting {
// the list.
if ac.state != connectivity.Ready {
return false return false
} }
// ac.state is Ready, try to find the connected address.
var curAddrFound bool var curAddrFound bool
for _, a := range addrs { for _, a := range addrs {
if reflect.DeepEqual(ac.curAddr, a) { if reflect.DeepEqual(ac.curAddr, a) {
@ -970,6 +1023,9 @@ func (ac *addrConn) resetTransport() {
// The spec doesn't mention what should be done for multiple addresses. // The spec doesn't mention what should be done for multiple addresses.
// https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm // https://github.com/grpc/grpc/blob/master/doc/connection-backoff.md#proposed-backoff-algorithm
connectDeadline := time.Now().Add(dialDuration) connectDeadline := time.Now().Add(dialDuration)
ac.updateConnectivityState(connectivity.Connecting)
ac.transport = nil
ac.mu.Unlock() ac.mu.Unlock()
newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline) newTr, addr, reconnect, err := ac.tryAllAddrs(addrs, connectDeadline)
@ -1004,55 +1060,32 @@ func (ac *addrConn) resetTransport() {
ac.mu.Lock() ac.mu.Lock()
if ac.state == connectivity.Shutdown { if ac.state == connectivity.Shutdown {
newTr.Close()
ac.mu.Unlock() ac.mu.Unlock()
newTr.Close()
return return
} }
ac.curAddr = addr ac.curAddr = addr
ac.transport = newTr ac.transport = newTr
ac.backoffIdx = 0 ac.backoffIdx = 0
healthCheckConfig := ac.cc.healthCheckConfig()
// LB channel health checking is only enabled when all the four requirements below are met:
// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption,
// 2. the internal.HealthCheckFunc is set by importing the grpc/healthcheck package,
// 3. a service config with non-empty healthCheckConfig field is provided,
// 4. the current load balancer allows it.
hctx, hcancel := context.WithCancel(ac.ctx) hctx, hcancel := context.WithCancel(ac.ctx)
healthcheckManagingState := false ac.startHealthCheck(hctx)
if !ac.cc.dopts.disableHealthCheck && healthCheckConfig != nil && ac.scopts.HealthCheckEnabled {
if ac.cc.dopts.healthCheckFunc == nil {
// TODO: add a link to the health check doc in the error message.
grpclog.Error("the client side LB channel health check function has not been set.")
} else {
// TODO(deklerk) refactor to just return transport
go ac.startHealthCheck(hctx, newTr, addr, healthCheckConfig.ServiceName)
healthcheckManagingState = true
}
}
if !healthcheckManagingState {
ac.updateConnectivityState(connectivity.Ready)
}
ac.mu.Unlock() ac.mu.Unlock()
// Block until the created transport is down. And when this happens, // Block until the created transport is down. And when this happens,
// we restart from the top of the addr list. // we restart from the top of the addr list.
<-reconnect.Done() <-reconnect.Done()
hcancel() hcancel()
// restart connecting - the top of the loop will set state to
// Need to reconnect after a READY, the addrConn enters // CONNECTING. This is against the current connectivity semantics doc,
// TRANSIENT_FAILURE. // however it allows for graceful behavior for RPCs not yet dispatched
// - unfortunate timing would otherwise lead to the RPC failing even
// though the TRANSIENT_FAILURE state (called for by the doc) would be
// instantaneous.
// //
// This will set addrConn to TRANSIENT_FAILURE for a very short period // Ideally we should transition to Idle here and block until there is
// of time, and turns CONNECTING. It seems reasonable to skip this, but // RPC activity that leads to the balancer requesting a reconnect of
// READY-CONNECTING is not a valid transition. // the associated SubConn.
ac.mu.Lock()
if ac.state == connectivity.Shutdown {
ac.mu.Unlock()
return
}
ac.updateConnectivityState(connectivity.TransientFailure)
ac.mu.Unlock()
} }
} }
@ -1066,8 +1099,6 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
ac.mu.Unlock() ac.mu.Unlock()
return nil, resolver.Address{}, nil, errConnClosing return nil, resolver.Address{}, nil, errConnClosing
} }
ac.updateConnectivityState(connectivity.Connecting)
ac.transport = nil
ac.cc.mu.RLock() ac.cc.mu.RLock()
ac.dopts.copts.KeepaliveParams = ac.cc.mkp ac.dopts.copts.KeepaliveParams = ac.cc.mkp
@ -1111,14 +1142,35 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
Authority: ac.cc.authority, Authority: ac.cc.authority,
} }
once := sync.Once{}
onGoAway := func(r transport.GoAwayReason) { onGoAway := func(r transport.GoAwayReason) {
ac.mu.Lock() ac.mu.Lock()
ac.adjustParams(r) ac.adjustParams(r)
once.Do(func() {
if ac.state == connectivity.Ready {
// Prevent this SubConn from being used for new RPCs by setting its
// state to Connecting.
//
// TODO: this should be Idle when grpc-go properly supports it.
ac.updateConnectivityState(connectivity.Connecting)
}
})
ac.mu.Unlock() ac.mu.Unlock()
reconnect.Fire() reconnect.Fire()
} }
onClose := func() { onClose := func() {
ac.mu.Lock()
once.Do(func() {
if ac.state == connectivity.Ready {
// Prevent this SubConn from being used for new RPCs by setting its
// state to Connecting.
//
// TODO: this should be Idle when grpc-go properly supports it.
ac.updateConnectivityState(connectivity.Connecting)
}
})
ac.mu.Unlock()
close(onCloseCalled) close(onCloseCalled)
reconnect.Fire() reconnect.Fire()
} }
@ -1140,60 +1192,99 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
return nil, nil, err return nil, nil, err
} }
if ac.dopts.reqHandshake == envconfig.RequireHandshakeOn { select {
select { case <-time.After(connectDeadline.Sub(time.Now())):
case <-time.After(connectDeadline.Sub(time.Now())): // We didn't get the preface in time.
// We didn't get the preface in time. newTr.Close()
newTr.Close() grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr)
grpclog.Warningf("grpc: addrConn.createTransport failed to connect to %v: didn't receive server preface in time. Reconnecting...", addr) return nil, nil, errors.New("timed out waiting for server handshake")
return nil, nil, errors.New("timed out waiting for server handshake") case <-prefaceReceived:
case <-prefaceReceived: // We got the preface - huzzah! things are good.
// We got the preface - huzzah! things are good. case <-onCloseCalled:
case <-onCloseCalled: // The transport has already closed - noop.
// The transport has already closed - noop. return nil, nil, errors.New("connection closed")
return nil, nil, errors.New("connection closed") // TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix.
// TODO(deklerk) this should bail on ac.ctx.Done(). Add a test and fix.
}
} }
return newTr, reconnect, nil return newTr, reconnect, nil
} }
func (ac *addrConn) startHealthCheck(ctx context.Context, newTr transport.ClientTransport, addr resolver.Address, serviceName string) { // startHealthCheck starts the health checking stream (RPC) to watch the health
// Set up the health check helper functions // stats of this connection if health checking is requested and configured.
newStream := func() (interface{}, error) { //
return ac.newClientStream(ctx, &StreamDesc{ServerStreams: true}, "/grpc.health.v1.Health/Watch", newTr) // LB channel health checking is enabled when all requirements below are met:
// 1. it is not disabled by the user with the WithDisableHealthCheck DialOption
// 2. internal.HealthCheckFunc is set by importing the grpc/healthcheck package
// 3. a service config with non-empty healthCheckConfig field is provided
// 4. the load balancer requests it
//
// It sets addrConn to READY if the health checking stream is not started.
//
// Caller must hold ac.mu.
func (ac *addrConn) startHealthCheck(ctx context.Context) {
var healthcheckManagingState bool
defer func() {
if !healthcheckManagingState {
ac.updateConnectivityState(connectivity.Ready)
}
}()
if ac.cc.dopts.disableHealthCheck {
return
} }
firstReady := true healthCheckConfig := ac.cc.healthCheckConfig()
reportHealth := func(ok bool) { if healthCheckConfig == nil {
return
}
if !ac.scopts.HealthCheckEnabled {
return
}
healthCheckFunc := ac.cc.dopts.healthCheckFunc
if healthCheckFunc == nil {
// The health package is not imported to set health check function.
//
// TODO: add a link to the health check doc in the error message.
grpclog.Error("Health check is requested but health check function is not set.")
return
}
healthcheckManagingState = true
// Set up the health check helper functions.
currentTr := ac.transport
newStream := func(method string) (interface{}, error) {
ac.mu.Lock()
if ac.transport != currentTr {
ac.mu.Unlock()
return nil, status.Error(codes.Canceled, "the provided transport is no longer valid to use")
}
ac.mu.Unlock()
return newNonRetryClientStream(ctx, &StreamDesc{ServerStreams: true}, method, currentTr, ac)
}
setConnectivityState := func(s connectivity.State) {
ac.mu.Lock() ac.mu.Lock()
defer ac.mu.Unlock() defer ac.mu.Unlock()
if ac.transport != newTr { if ac.transport != currentTr {
return return
} }
if ok { ac.updateConnectivityState(s)
if firstReady {
firstReady = false
ac.curAddr = addr
}
ac.updateConnectivityState(connectivity.Ready)
} else {
ac.updateConnectivityState(connectivity.TransientFailure)
}
} }
err := ac.cc.dopts.healthCheckFunc(ctx, newStream, reportHealth, serviceName) // Start the health checking stream.
if err != nil { go func() {
if status.Code(err) == codes.Unimplemented { err := ac.cc.dopts.healthCheckFunc(ctx, newStream, setConnectivityState, healthCheckConfig.ServiceName)
if channelz.IsOn() { if err != nil {
channelz.AddTraceEvent(ac.channelzID, &channelz.TraceEventDesc{ if status.Code(err) == codes.Unimplemented {
Desc: "Subchannel health check is unimplemented at server side, thus health check is disabled", if channelz.IsOn() {
Severity: channelz.CtError, channelz.AddTraceEvent(ac.channelzID, &channelz.TraceEventDesc{
}) Desc: "Subchannel health check is unimplemented at server side, thus health check is disabled",
Severity: channelz.CtError,
})
}
grpclog.Error("Subchannel health check is unimplemented at server side, thus health check is disabled")
} else {
grpclog.Errorf("HealthCheckFunc exits with unexpected error %v", err)
} }
grpclog.Error("Subchannel health check is unimplemented at server side, thus health check is disabled")
} else {
grpclog.Errorf("HealthCheckFunc exits with unexpected error %v", err)
} }
} }()
} }
func (ac *addrConn) resetConnectBackoff() { func (ac *addrConn) resetConnectBackoff() {

View file

@ -132,7 +132,8 @@ const (
// Unavailable indicates the service is currently unavailable. // Unavailable indicates the service is currently unavailable.
// This is a most likely a transient condition and may be corrected // This is a most likely a transient condition and may be corrected
// by retrying with a backoff. // by retrying with a backoff. Note that it is not always safe to retry
// non-idempotent operations.
// //
// See litmus test above for deciding between FailedPrecondition, // See litmus test above for deciding between FailedPrecondition,
// Aborted, and Unavailable. // Aborted, and Unavailable.

View file

@ -278,24 +278,22 @@ type ChannelzSecurityValue interface {
// TLSChannelzSecurityValue defines the struct that TLS protocol should return // TLSChannelzSecurityValue defines the struct that TLS protocol should return
// from GetSecurityValue(), containing security info like cipher and certificate used. // from GetSecurityValue(), containing security info like cipher and certificate used.
type TLSChannelzSecurityValue struct { type TLSChannelzSecurityValue struct {
ChannelzSecurityValue
StandardName string StandardName string
LocalCertificate []byte LocalCertificate []byte
RemoteCertificate []byte RemoteCertificate []byte
} }
func (*TLSChannelzSecurityValue) isChannelzSecurityValue() {}
// OtherChannelzSecurityValue defines the struct that non-TLS protocol should return // OtherChannelzSecurityValue defines the struct that non-TLS protocol should return
// from GetSecurityValue(), which contains protocol specific security info. Note // from GetSecurityValue(), which contains protocol specific security info. Note
// the Value field will be sent to users of channelz requesting channel info, and // the Value field will be sent to users of channelz requesting channel info, and
// thus sensitive info should better be avoided. // thus sensitive info should better be avoided.
type OtherChannelzSecurityValue struct { type OtherChannelzSecurityValue struct {
ChannelzSecurityValue
Name string Name string
Value proto.Message Value proto.Message
} }
func (*OtherChannelzSecurityValue) isChannelzSecurityValue() {}
var cipherSuiteLookup = map[uint16]string{ var cipherSuiteLookup = map[uint16]string{
tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA", tls.TLS_RSA_WITH_RC4_128_SHA: "TLS_RSA_WITH_RC4_128_SHA",
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA", tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",

View file

@ -39,8 +39,12 @@ import (
// dialOptions configure a Dial call. dialOptions are set by the DialOption // dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial. // values passed to Dial.
type dialOptions struct { type dialOptions struct {
unaryInt UnaryClientInterceptor unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor streamInt StreamClientInterceptor
chainUnaryInts []UnaryClientInterceptor
chainStreamInts []StreamClientInterceptor
cp Compressor cp Compressor
dc Decompressor dc Decompressor
bs backoff.Strategy bs backoff.Strategy
@ -56,7 +60,6 @@ type dialOptions struct {
balancerBuilder balancer.Builder balancerBuilder balancer.Builder
// This is to support grpclb. // This is to support grpclb.
resolverBuilder resolver.Builder resolverBuilder resolver.Builder
reqHandshake envconfig.RequireHandshakeSetting
channelzParentID int64 channelzParentID int64
disableServiceConfig bool disableServiceConfig bool
disableRetry bool disableRetry bool
@ -96,17 +99,6 @@ func newFuncDialOption(f func(*dialOptions)) *funcDialOption {
} }
} }
// WithWaitForHandshake blocks until the initial settings frame is received from
// the server before assigning RPCs to the connection.
//
// Deprecated: this is the default behavior, and this option will be removed
// after the 1.18 release.
func WithWaitForHandshake() DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.reqHandshake = envconfig.RequireHandshakeOn
})
}
// WithWriteBufferSize determines how much data can be batched before doing a // WithWriteBufferSize determines how much data can be batched before doing a
// write on the wire. The corresponding memory allocation for this buffer will // write on the wire. The corresponding memory allocation for this buffer will
// be twice the size to keep syscalls low. The default value for this buffer is // be twice the size to keep syscalls low. The default value for this buffer is
@ -152,7 +144,8 @@ func WithInitialConnWindowSize(s int32) DialOption {
// WithMaxMsgSize returns a DialOption which sets the maximum message size the // WithMaxMsgSize returns a DialOption which sets the maximum message size the
// client can receive. // client can receive.
// //
// Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead. // Deprecated: use WithDefaultCallOptions(MaxCallRecvMsgSize(s)) instead. Will
// be supported throughout 1.x.
func WithMaxMsgSize(s int) DialOption { func WithMaxMsgSize(s int) DialOption {
return WithDefaultCallOptions(MaxCallRecvMsgSize(s)) return WithDefaultCallOptions(MaxCallRecvMsgSize(s))
} }
@ -168,7 +161,8 @@ func WithDefaultCallOptions(cos ...CallOption) DialOption {
// WithCodec returns a DialOption which sets a codec for message marshaling and // WithCodec returns a DialOption which sets a codec for message marshaling and
// unmarshaling. // unmarshaling.
// //
// Deprecated: use WithDefaultCallOptions(ForceCodec(_)) instead. // Deprecated: use WithDefaultCallOptions(ForceCodec(_)) instead. Will be
// supported throughout 1.x.
func WithCodec(c Codec) DialOption { func WithCodec(c Codec) DialOption {
return WithDefaultCallOptions(CallCustomCodec(c)) return WithDefaultCallOptions(CallCustomCodec(c))
} }
@ -177,7 +171,7 @@ func WithCodec(c Codec) DialOption {
// message compression. It has lower priority than the compressor set by the // message compression. It has lower priority than the compressor set by the
// UseCompressor CallOption. // UseCompressor CallOption.
// //
// Deprecated: use UseCompressor instead. // Deprecated: use UseCompressor instead. Will be supported throughout 1.x.
func WithCompressor(cp Compressor) DialOption { func WithCompressor(cp Compressor) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.cp = cp o.cp = cp
@ -192,7 +186,8 @@ func WithCompressor(cp Compressor) DialOption {
// message. If no compressor is registered for the encoding, an Unimplemented // message. If no compressor is registered for the encoding, an Unimplemented
// status error will be returned. // status error will be returned.
// //
// Deprecated: use encoding.RegisterCompressor instead. // Deprecated: use encoding.RegisterCompressor instead. Will be supported
// throughout 1.x.
func WithDecompressor(dc Decompressor) DialOption { func WithDecompressor(dc Decompressor) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.dc = dc o.dc = dc
@ -203,7 +198,7 @@ func WithDecompressor(dc Decompressor) DialOption {
// Name resolver will be ignored if this DialOption is specified. // Name resolver will be ignored if this DialOption is specified.
// //
// Deprecated: use the new balancer APIs in balancer package and // Deprecated: use the new balancer APIs in balancer package and
// WithBalancerName. // WithBalancerName. Will be removed in a future 1.x release.
func WithBalancer(b Balancer) DialOption { func WithBalancer(b Balancer) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.balancerBuilder = &balancerWrapperBuilder{ o.balancerBuilder = &balancerWrapperBuilder{
@ -219,7 +214,8 @@ func WithBalancer(b Balancer) DialOption {
// The balancer cannot be overridden by balancer option specified by service // The balancer cannot be overridden by balancer option specified by service
// config. // config.
// //
// This is an EXPERIMENTAL API. // Deprecated: use WithDefaultServiceConfig and WithDisableServiceConfig
// instead. Will be removed in a future 1.x release.
func WithBalancerName(balancerName string) DialOption { func WithBalancerName(balancerName string) DialOption {
builder := balancer.Get(balancerName) builder := balancer.Get(balancerName)
if builder == nil { if builder == nil {
@ -240,9 +236,10 @@ func withResolverBuilder(b resolver.Builder) DialOption {
// WithServiceConfig returns a DialOption which has a channel to read the // WithServiceConfig returns a DialOption which has a channel to read the
// service configuration. // service configuration.
// //
// Deprecated: service config should be received through name resolver, as // Deprecated: service config should be received through name resolver or via
// specified here. // WithDefaultServiceConfig, as specified at
// https://github.com/grpc/grpc/blob/master/doc/service_config.md // https://github.com/grpc/grpc/blob/master/doc/service_config.md. Will be
// removed in a future 1.x release.
func WithServiceConfig(c <-chan ServiceConfig) DialOption { func WithServiceConfig(c <-chan ServiceConfig) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.scChan = c o.scChan = c
@ -325,7 +322,8 @@ func WithCredentialsBundle(b credentials.Bundle) DialOption {
// WithTimeout returns a DialOption that configures a timeout for dialing a // WithTimeout returns a DialOption that configures a timeout for dialing a
// ClientConn initially. This is valid if and only if WithBlock() is present. // ClientConn initially. This is valid if and only if WithBlock() is present.
// //
// Deprecated: use DialContext and context.WithTimeout instead. // Deprecated: use DialContext and context.WithTimeout instead. Will be
// supported throughout 1.x.
func WithTimeout(d time.Duration) DialOption { func WithTimeout(d time.Duration) DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.timeout = d o.timeout = d
@ -352,7 +350,8 @@ func init() {
// is returned by f, gRPC checks the error's Temporary() method to decide if it // is returned by f, gRPC checks the error's Temporary() method to decide if it
// should try to reconnect to the network address. // should try to reconnect to the network address.
// //
// Deprecated: use WithContextDialer instead // Deprecated: use WithContextDialer instead. Will be supported throughout
// 1.x.
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption { func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
return WithContextDialer( return WithContextDialer(
func(ctx context.Context, addr string) (net.Conn, error) { func(ctx context.Context, addr string) (net.Conn, error) {
@ -414,6 +413,17 @@ func WithUnaryInterceptor(f UnaryClientInterceptor) DialOption {
}) })
} }
// WithChainUnaryInterceptor returns a DialOption that specifies the chained
// interceptor for unary RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All interceptors added by this method will be chained, and the interceptor
// defined by WithUnaryInterceptor will always be prepended to the chain.
func WithChainUnaryInterceptor(interceptors ...UnaryClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.chainUnaryInts = append(o.chainUnaryInts, interceptors...)
})
}
// WithStreamInterceptor returns a DialOption that specifies the interceptor for // WithStreamInterceptor returns a DialOption that specifies the interceptor for
// streaming RPCs. // streaming RPCs.
func WithStreamInterceptor(f StreamClientInterceptor) DialOption { func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
@ -422,6 +432,17 @@ func WithStreamInterceptor(f StreamClientInterceptor) DialOption {
}) })
} }
// WithChainStreamInterceptor returns a DialOption that specifies the chained
// interceptor for unary RPCs. The first interceptor will be the outer most,
// while the last interceptor will be the inner most wrapper around the real call.
// All interceptors added by this method will be chained, and the interceptor
// defined by WithStreamInterceptor will always be prepended to the chain.
func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.chainStreamInts = append(o.chainStreamInts, interceptors...)
})
}
// WithAuthority returns a DialOption that specifies the value to be used as the // WithAuthority returns a DialOption that specifies the value to be used as the
// :authority pseudo-header. This value only works with WithInsecure and has no // :authority pseudo-header. This value only works with WithInsecure and has no
// effect if TransportCredentials are present. // effect if TransportCredentials are present.
@ -440,12 +461,12 @@ func WithChannelzParentID(id int64) DialOption {
}) })
} }
// WithDisableServiceConfig returns a DialOption that causes grpc to ignore any // WithDisableServiceConfig returns a DialOption that causes gRPC to ignore any
// service config provided by the resolver and provides a hint to the resolver // service config provided by the resolver and provides a hint to the resolver
// to not fetch service configs. // to not fetch service configs.
// //
// Note that, this dial option only disables service config from resolver. If // Note that this dial option only disables service config from resolver. If
// default service config is provided, grpc will use the default service config. // default service config is provided, gRPC will use the default service config.
func WithDisableServiceConfig() DialOption { func WithDisableServiceConfig() DialOption {
return newFuncDialOption(func(o *dialOptions) { return newFuncDialOption(func(o *dialOptions) {
o.disableServiceConfig = true o.disableServiceConfig = true
@ -454,8 +475,10 @@ func WithDisableServiceConfig() DialOption {
// WithDefaultServiceConfig returns a DialOption that configures the default // WithDefaultServiceConfig returns a DialOption that configures the default
// service config, which will be used in cases where: // service config, which will be used in cases where:
// 1. WithDisableServiceConfig is called. //
// 2. Resolver does not return service config or if the resolver gets and invalid config. // 1. WithDisableServiceConfig is also used.
// 2. Resolver does not return a service config or if the resolver returns an
// invalid service config.
// //
// This API is EXPERIMENTAL. // This API is EXPERIMENTAL.
func WithDefaultServiceConfig(s string) DialOption { func WithDefaultServiceConfig(s string) DialOption {
@ -511,7 +534,6 @@ func withHealthCheckFunc(f internal.HealthChecker) DialOption {
func defaultDialOptions() dialOptions { func defaultDialOptions() dialOptions {
return dialOptions{ return dialOptions{
disableRetry: !envconfig.Retry, disableRetry: !envconfig.Retry,
reqHandshake: envconfig.RequireHandshake,
healthCheckFunc: internal.HealthCheckFunc, healthCheckFunc: internal.HealthCheckFunc,
copts: transport.ConnectOptions{ copts: transport.ConnectOptions{
WriteBufferSize: defaultWriteBufSize, WriteBufferSize: defaultWriteBufSize,

View file

@ -7,13 +7,13 @@ require (
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
github.com/golang/mock v1.1.1 github.com/golang/mock v1.1.1
github.com/golang/protobuf v1.2.0 github.com/golang/protobuf v1.2.0
github.com/google/go-cmp v0.2.0
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3 golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3
golang.org/x/net v0.0.0-20190311183353-d8887717615a golang.org/x/net v0.0.0-20190311183353-d8887717615a
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f // indirect
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a
golang.org/x/tools v0.0.0-20190311212946-11955173bddd golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135
google.golang.org/appengine v1.1.0 // indirect google.golang.org/appengine v1.1.0 // indirect
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099 honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc
) )

View file

@ -26,6 +26,7 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/internal" "google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/backoff" "google.golang.org/grpc/internal/backoff"
@ -51,7 +52,11 @@ func init() {
internal.HealthCheckFunc = clientHealthCheck internal.HealthCheckFunc = clientHealthCheck
} }
func clientHealthCheck(ctx context.Context, newStream func() (interface{}, error), reportHealth func(bool), service string) error { const healthCheckMethod = "/grpc.health.v1.Health/Watch"
// This function implements the protocol defined at:
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
func clientHealthCheck(ctx context.Context, newStream func(string) (interface{}, error), setConnectivityState func(connectivity.State), service string) error {
tryCnt := 0 tryCnt := 0
retryConnection: retryConnection:
@ -65,7 +70,8 @@ retryConnection:
if ctx.Err() != nil { if ctx.Err() != nil {
return nil return nil
} }
rawS, err := newStream() setConnectivityState(connectivity.Connecting)
rawS, err := newStream(healthCheckMethod)
if err != nil { if err != nil {
continue retryConnection continue retryConnection
} }
@ -73,7 +79,7 @@ retryConnection:
s, ok := rawS.(grpc.ClientStream) s, ok := rawS.(grpc.ClientStream)
// Ideally, this should never happen. But if it happens, the server is marked as healthy for LBing purposes. // Ideally, this should never happen. But if it happens, the server is marked as healthy for LBing purposes.
if !ok { if !ok {
reportHealth(true) setConnectivityState(connectivity.Ready)
return fmt.Errorf("newStream returned %v (type %T); want grpc.ClientStream", rawS, rawS) return fmt.Errorf("newStream returned %v (type %T); want grpc.ClientStream", rawS, rawS)
} }
@ -89,19 +95,23 @@ retryConnection:
// Reports healthy for the LBing purposes if health check is not implemented in the server. // Reports healthy for the LBing purposes if health check is not implemented in the server.
if status.Code(err) == codes.Unimplemented { if status.Code(err) == codes.Unimplemented {
reportHealth(true) setConnectivityState(connectivity.Ready)
return err return err
} }
// Reports unhealthy if server's Watch method gives an error other than UNIMPLEMENTED. // Reports unhealthy if server's Watch method gives an error other than UNIMPLEMENTED.
if err != nil { if err != nil {
reportHealth(false) setConnectivityState(connectivity.TransientFailure)
continue retryConnection continue retryConnection
} }
// As a message has been received, removes the need for backoff for the next retry by reseting the try count. // As a message has been received, removes the need for backoff for the next retry by reseting the try count.
tryCnt = 0 tryCnt = 0
reportHealth(resp.Status == healthpb.HealthCheckResponse_SERVING) if resp.Status == healthpb.HealthCheckResponse_SERVING {
setConnectivityState(connectivity.Ready)
} else {
setConnectivityState(connectivity.TransientFailure)
}
} }
} }
} }

View file

@ -24,6 +24,7 @@
package channelz package channelz
import ( import (
"fmt"
"sort" "sort"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -95,9 +96,14 @@ func (d *dbWrapper) get() *channelMap {
// NewChannelzStorage initializes channelz data storage and id generator. // NewChannelzStorage initializes channelz data storage and id generator.
// //
// This function returns a cleanup function to wait for all channelz state to be reset by the
// grpc goroutines when those entities get closed. By using this cleanup function, we make sure tests
// don't mess up each other, i.e. lingering goroutine from previous test doing entity removal happen
// to remove some entity just register by the new test, since the id space is the same.
//
// Note: This function is exported for testing purpose only. User should not call // Note: This function is exported for testing purpose only. User should not call
// it in most cases. // it in most cases.
func NewChannelzStorage() { func NewChannelzStorage() (cleanup func() error) {
db.set(&channelMap{ db.set(&channelMap{
topLevelChannels: make(map[int64]struct{}), topLevelChannels: make(map[int64]struct{}),
channels: make(map[int64]*channel), channels: make(map[int64]*channel),
@ -107,6 +113,28 @@ func NewChannelzStorage() {
subChannels: make(map[int64]*subChannel), subChannels: make(map[int64]*subChannel),
}) })
idGen.reset() idGen.reset()
return func() error {
var err error
cm := db.get()
if cm == nil {
return nil
}
for i := 0; i < 1000; i++ {
cm.mu.Lock()
if len(cm.topLevelChannels) == 0 && len(cm.servers) == 0 && len(cm.channels) == 0 && len(cm.subChannels) == 0 && len(cm.listenSockets) == 0 && len(cm.normalSockets) == 0 {
cm.mu.Unlock()
// all things stored in the channelz map have been cleared.
return nil
}
cm.mu.Unlock()
time.Sleep(10 * time.Millisecond)
}
cm.mu.Lock()
err = fmt.Errorf("after 10s the channelz map has not been cleaned up yet, topchannels: %d, servers: %d, channels: %d, subchannels: %d, listen sockets: %d, normal sockets: %d", len(cm.topLevelChannels), len(cm.servers), len(cm.channels), len(cm.subChannels), len(cm.listenSockets), len(cm.normalSockets))
cm.mu.Unlock()
return err
}
} }
// GetTopChannels returns a slice of top channel's ChannelMetric, along with a // GetTopChannels returns a slice of top channel's ChannelMetric, along with a

View file

@ -25,40 +25,11 @@ import (
) )
const ( const (
prefix = "GRPC_GO_" prefix = "GRPC_GO_"
retryStr = prefix + "RETRY" retryStr = prefix + "RETRY"
requireHandshakeStr = prefix + "REQUIRE_HANDSHAKE"
)
// RequireHandshakeSetting describes the settings for handshaking.
type RequireHandshakeSetting int
const (
// RequireHandshakeOn indicates to wait for handshake before considering a
// connection ready/successful.
RequireHandshakeOn RequireHandshakeSetting = iota
// RequireHandshakeOff indicates to not wait for handshake before
// considering a connection ready/successful.
RequireHandshakeOff
) )
var ( var (
// Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on". // Retry is set if retry is explicitly enabled via "GRPC_GO_RETRY=on".
Retry = strings.EqualFold(os.Getenv(retryStr), "on") Retry = strings.EqualFold(os.Getenv(retryStr), "on")
// RequireHandshake is set based upon the GRPC_GO_REQUIRE_HANDSHAKE
// environment variable.
//
// Will be removed after the 1.18 release.
RequireHandshake = RequireHandshakeOn
) )
func init() {
switch strings.ToLower(os.Getenv(requireHandshakeStr)) {
case "on":
fallthrough
default:
RequireHandshake = RequireHandshakeOn
case "off":
RequireHandshake = RequireHandshakeOff
}
}

View file

@ -23,6 +23,8 @@ package internal
import ( import (
"context" "context"
"time" "time"
"google.golang.org/grpc/connectivity"
) )
var ( var (
@ -37,10 +39,25 @@ var (
// KeepaliveMinPingTime is the minimum ping interval. This must be 10s by // KeepaliveMinPingTime is the minimum ping interval. This must be 10s by
// default, but tests may wish to set it lower for convenience. // default, but tests may wish to set it lower for convenience.
KeepaliveMinPingTime = 10 * time.Second KeepaliveMinPingTime = 10 * time.Second
// ParseServiceConfig is a function to parse JSON service configs into
// opaque data structures.
ParseServiceConfig func(sc string) (interface{}, error)
// StatusRawProto is exported by status/status.go. This func returns a
// pointer to the wrapped Status proto for a given status.Status without a
// call to proto.Clone(). The returned Status proto should not be mutated by
// the caller.
StatusRawProto interface{} // func (*status.Status) *spb.Status
) )
// HealthChecker defines the signature of the client-side LB channel health checking function. // HealthChecker defines the signature of the client-side LB channel health checking function.
type HealthChecker func(ctx context.Context, newStream func() (interface{}, error), reportHealth func(bool), serviceName string) error //
// The implementation is expected to create a health checking RPC stream by
// calling newStream(), watch for the health status of serviceName, and report
// it's health back by calling setConnectivityState().
//
// The health checking protocol is defined at:
// https://github.com/grpc/grpc/blob/master/doc/health-checking.md
type HealthChecker func(ctx context.Context, newStream func(string) (interface{}, error), setConnectivityState func(connectivity.State), serviceName string) error
const ( const (
// CredsBundleModeFallback switches GoogleDefaultCreds to fallback mode. // CredsBundleModeFallback switches GoogleDefaultCreds to fallback mode.

View file

@ -23,6 +23,7 @@ import (
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
"sync/atomic"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
@ -84,12 +85,24 @@ func (il *itemList) isEmpty() bool {
// the control buffer of transport. They represent different aspects of // the control buffer of transport. They represent different aspects of
// control tasks, e.g., flow control, settings, streaming resetting, etc. // control tasks, e.g., flow control, settings, streaming resetting, etc.
// maxQueuedTransportResponseFrames is the most queued "transport response"
// frames we will buffer before preventing new reads from occurring on the
// transport. These are control frames sent in response to client requests,
// such as RST_STREAM due to bad headers or settings acks.
const maxQueuedTransportResponseFrames = 50
type cbItem interface {
isTransportResponseFrame() bool
}
// registerStream is used to register an incoming stream with loopy writer. // registerStream is used to register an incoming stream with loopy writer.
type registerStream struct { type registerStream struct {
streamID uint32 streamID uint32
wq *writeQuota wq *writeQuota
} }
func (*registerStream) isTransportResponseFrame() bool { return false }
// headerFrame is also used to register stream on the client-side. // headerFrame is also used to register stream on the client-side.
type headerFrame struct { type headerFrame struct {
streamID uint32 streamID uint32
@ -102,6 +115,10 @@ type headerFrame struct {
onOrphaned func(error) // Valid on client-side onOrphaned func(error) // Valid on client-side
} }
func (h *headerFrame) isTransportResponseFrame() bool {
return h.cleanup != nil && h.cleanup.rst // Results in a RST_STREAM
}
type cleanupStream struct { type cleanupStream struct {
streamID uint32 streamID uint32
rst bool rst bool
@ -109,6 +126,8 @@ type cleanupStream struct {
onWrite func() onWrite func()
} }
func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM
type dataFrame struct { type dataFrame struct {
streamID uint32 streamID uint32
endStream bool endStream bool
@ -119,27 +138,41 @@ type dataFrame struct {
onEachWrite func() onEachWrite func()
} }
func (*dataFrame) isTransportResponseFrame() bool { return false }
type incomingWindowUpdate struct { type incomingWindowUpdate struct {
streamID uint32 streamID uint32
increment uint32 increment uint32
} }
func (*incomingWindowUpdate) isTransportResponseFrame() bool { return false }
type outgoingWindowUpdate struct { type outgoingWindowUpdate struct {
streamID uint32 streamID uint32
increment uint32 increment uint32
} }
func (*outgoingWindowUpdate) isTransportResponseFrame() bool {
return false // window updates are throttled by thresholds
}
type incomingSettings struct { type incomingSettings struct {
ss []http2.Setting ss []http2.Setting
} }
func (*incomingSettings) isTransportResponseFrame() bool { return true } // Results in a settings ACK
type outgoingSettings struct { type outgoingSettings struct {
ss []http2.Setting ss []http2.Setting
} }
func (*outgoingSettings) isTransportResponseFrame() bool { return false }
type incomingGoAway struct { type incomingGoAway struct {
} }
func (*incomingGoAway) isTransportResponseFrame() bool { return false }
type goAway struct { type goAway struct {
code http2.ErrCode code http2.ErrCode
debugData []byte debugData []byte
@ -147,15 +180,21 @@ type goAway struct {
closeConn bool closeConn bool
} }
func (*goAway) isTransportResponseFrame() bool { return false }
type ping struct { type ping struct {
ack bool ack bool
data [8]byte data [8]byte
} }
func (*ping) isTransportResponseFrame() bool { return true }
type outFlowControlSizeRequest struct { type outFlowControlSizeRequest struct {
resp chan uint32 resp chan uint32
} }
func (*outFlowControlSizeRequest) isTransportResponseFrame() bool { return false }
type outStreamState int type outStreamState int
const ( const (
@ -238,6 +277,14 @@ type controlBuffer struct {
consumerWaiting bool consumerWaiting bool
list *itemList list *itemList
err error err error
// transportResponseFrames counts the number of queued items that represent
// the response of an action initiated by the peer. trfChan is created
// when transportResponseFrames >= maxQueuedTransportResponseFrames and is
// closed and nilled when transportResponseFrames drops below the
// threshold. Both fields are protected by mu.
transportResponseFrames int
trfChan atomic.Value // *chan struct{}
} }
func newControlBuffer(done <-chan struct{}) *controlBuffer { func newControlBuffer(done <-chan struct{}) *controlBuffer {
@ -248,12 +295,24 @@ func newControlBuffer(done <-chan struct{}) *controlBuffer {
} }
} }
func (c *controlBuffer) put(it interface{}) error { // throttle blocks if there are too many incomingSettings/cleanupStreams in the
// controlbuf.
func (c *controlBuffer) throttle() {
ch, _ := c.trfChan.Load().(*chan struct{})
if ch != nil {
select {
case <-*ch:
case <-c.done:
}
}
}
func (c *controlBuffer) put(it cbItem) error {
_, err := c.executeAndPut(nil, it) _, err := c.executeAndPut(nil, it)
return err return err
} }
func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{}) (bool, error) { func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it cbItem) (bool, error) {
var wakeUp bool var wakeUp bool
c.mu.Lock() c.mu.Lock()
if c.err != nil { if c.err != nil {
@ -271,6 +330,15 @@ func (c *controlBuffer) executeAndPut(f func(it interface{}) bool, it interface{
c.consumerWaiting = false c.consumerWaiting = false
} }
c.list.enqueue(it) c.list.enqueue(it)
if it.isTransportResponseFrame() {
c.transportResponseFrames++
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are adding the frame that puts us over the threshold; create
// a throttling channel.
ch := make(chan struct{})
c.trfChan.Store(&ch)
}
}
c.mu.Unlock() c.mu.Unlock()
if wakeUp { if wakeUp {
select { select {
@ -304,7 +372,17 @@ func (c *controlBuffer) get(block bool) (interface{}, error) {
return nil, c.err return nil, c.err
} }
if !c.list.isEmpty() { if !c.list.isEmpty() {
h := c.list.dequeue() h := c.list.dequeue().(cbItem)
if h.isTransportResponseFrame() {
if c.transportResponseFrames == maxQueuedTransportResponseFrames {
// We are removing the frame that put us over the
// threshold; close and clear the throttling channel.
ch := c.trfChan.Load().(*chan struct{})
close(*ch)
c.trfChan.Store((*chan struct{})(nil))
}
c.transportResponseFrames--
}
c.mu.Unlock() c.mu.Unlock()
return h, nil return h, nil
} }

View file

@ -149,6 +149,7 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
n = uint32(math.MaxInt32) n = uint32(math.MaxInt32)
} }
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock()
// estSenderQuota is the receiver's view of the maximum number of bytes the sender // estSenderQuota is the receiver's view of the maximum number of bytes the sender
// can send without a window update. // can send without a window update.
estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate)) estSenderQuota := int32(f.limit - (f.pendingData + f.pendingUpdate))
@ -169,10 +170,8 @@ func (f *inFlow) maybeAdjust(n uint32) uint32 {
// is padded; We will fallback on the current available window(at least a 1/4th of the limit). // is padded; We will fallback on the current available window(at least a 1/4th of the limit).
f.delta = n f.delta = n
} }
f.mu.Unlock()
return f.delta return f.delta
} }
f.mu.Unlock()
return 0 return 0
} }

View file

@ -24,6 +24,7 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -347,7 +348,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
ht.stats.HandleRPC(s.ctx, inHeader) ht.stats.HandleRPC(s.ctx, inHeader)
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
windowHandler: func(int) {}, windowHandler: func(int) {},
} }
@ -361,7 +362,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
for buf := make([]byte, readSize); ; { for buf := make([]byte, readSize); ; {
n, err := req.Body.Read(buf) n, err := req.Body.Read(buf)
if n > 0 { if n > 0 {
s.buf.put(recvMsg{data: buf[:n:n]}) s.buf.put(recvMsg{buffer: bytes.NewBuffer(buf[:n:n])})
buf = buf[n:] buf = buf[n:]
} }
if err != nil { if err != nil {

View file

@ -117,6 +117,8 @@ type http2Client struct {
onGoAway func(GoAwayReason) onGoAway func(GoAwayReason)
onClose func() onClose func()
bufferPool *bufferPool
} }
func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) {
@ -249,6 +251,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne
onGoAway: onGoAway, onGoAway: onGoAway,
onClose: onClose, onClose: onClose,
keepaliveEnabled: keepaliveEnabled, keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
if opts.InitialWindowSize >= defaultWindowSize { if opts.InitialWindowSize >= defaultWindowSize {
@ -367,6 +370,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream {
closeStream: func(err error) { closeStream: func(err error) {
t.CloseStream(s, err) t.CloseStream(s, err)
}, },
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -437,6 +441,15 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok { if md, added, ok := metadata.FromOutgoingContextRaw(ctx); ok {
var k string var k string
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
for _, vv := range added { for _, vv := range added {
for i, v := range vv { for i, v := range vv {
if i%2 == 0 { if i%2 == 0 {
@ -450,15 +463,6 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)}) headerFields = append(headerFields, hpack.HeaderField{Name: strings.ToLower(k), Value: encodeMetadataHeader(k, v)})
} }
} }
for k, vv := range md {
// HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set.
if isReservedHeader(k) {
continue
}
for _, v := range vv {
headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
}
}
} }
if md, ok := t.md.(*metadata.MD); ok { if md, ok := t.md.(*metadata.MD); ok {
for k, vv := range *md { for k, vv := range *md {
@ -489,6 +493,9 @@ func (t *http2Client) createAudience(callHdr *CallHdr) string {
} }
func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) { func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[string]string, error) {
if len(t.perRPCCreds) == 0 {
return nil, nil
}
authData := map[string]string{} authData := map[string]string{}
for _, c := range t.perRPCCreds { for _, c := range t.perRPCCreds {
data, err := c.GetRequestMetadata(ctx, audience) data, err := c.GetRequestMetadata(ctx, audience)
@ -509,7 +516,7 @@ func (t *http2Client) getTrAuthData(ctx context.Context, audience string) (map[s
} }
func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) { func (t *http2Client) getCallAuthData(ctx context.Context, audience string, callHdr *CallHdr) (map[string]string, error) {
callAuthData := map[string]string{} var callAuthData map[string]string
// Check if credentials.PerRPCCredentials were provided via call options. // Check if credentials.PerRPCCredentials were provided via call options.
// Note: if these credentials are provided both via dial options and call // Note: if these credentials are provided both via dial options and call
// options, then both sets of credentials will be applied. // options, then both sets of credentials will be applied.
@ -521,6 +528,7 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "transport: %v", err) return nil, status.Errorf(codes.Internal, "transport: %v", err)
} }
callAuthData = make(map[string]string, len(data))
for k, v := range data { for k, v := range data {
// Capital header names are illegal in HTTP/2 // Capital header names are illegal in HTTP/2
k = strings.ToLower(k) k = strings.ToLower(k)
@ -549,10 +557,9 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
s.write(recvMsg{err: err}) s.write(recvMsg{err: err})
close(s.done) close(s.done)
// If headerChan isn't closed, then close it. // If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 { if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
close(s.headerChan) close(s.headerChan)
} }
} }
hdr := &headerFrame{ hdr := &headerFrame{
hf: headerFields, hf: headerFields,
@ -713,7 +720,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
s.write(recvMsg{err: err}) s.write(recvMsg{err: err})
} }
// If headerChan isn't closed, then close it. // If headerChan isn't closed, then close it.
if atomic.SwapUint32(&s.headerDone, 1) == 0 { if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.noHeaders = true s.noHeaders = true
close(s.headerChan) close(s.headerChan)
} }
@ -765,6 +772,9 @@ func (t *http2Client) Close() error {
t.mu.Unlock() t.mu.Unlock()
return nil return nil
} }
// Call t.onClose before setting the state to closing to prevent the client
// from attempting to create new streams ASAP.
t.onClose()
t.state = closing t.state = closing
streams := t.activeStreams streams := t.activeStreams
t.activeStreams = nil t.activeStreams = nil
@ -785,7 +795,6 @@ func (t *http2Client) Close() error {
} }
t.statsHandler.HandleConn(t.ctx, connEnd) t.statsHandler.HandleConn(t.ctx, connEnd)
} }
t.onClose()
return err return err
} }
@ -794,21 +803,21 @@ func (t *http2Client) Close() error {
// stream is closed. If there are no active streams, the transport is closed // stream is closed. If there are no active streams, the transport is closed
// immediately. This does nothing if the transport is already draining or // immediately. This does nothing if the transport is already draining or
// closing. // closing.
func (t *http2Client) GracefulClose() error { func (t *http2Client) GracefulClose() {
t.mu.Lock() t.mu.Lock()
// Make sure we move to draining only from active. // Make sure we move to draining only from active.
if t.state == draining || t.state == closing { if t.state == draining || t.state == closing {
t.mu.Unlock() t.mu.Unlock()
return nil return
} }
t.state = draining t.state = draining
active := len(t.activeStreams) active := len(t.activeStreams)
t.mu.Unlock() t.mu.Unlock()
if active == 0 { if active == 0 {
return t.Close() t.Close()
return
} }
t.controlBuf.put(&incomingGoAway{}) t.controlBuf.put(&incomingGoAway{})
return nil
} }
// Write formats the data into HTTP2 data frame(s) and sends it out. The caller // Write formats the data into HTTP2 data frame(s) and sends it out. The caller
@ -946,9 +955,10 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
data := make([]byte, len(f.Data())) buffer := t.bufferPool.get()
copy(data, f.Data()) buffer.Reset()
s.write(recvMsg{data: data}) buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
} }
} }
// The server has closed the stream without sending trailers. Record that // The server has closed the stream without sending trailers. Record that
@ -973,9 +983,9 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
statusCode = codes.Unknown statusCode = codes.Unknown
} }
if statusCode == codes.Canceled { if statusCode == codes.Canceled {
// Our deadline was already exceeded, and that was likely the cause of if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
// this cancelation. Alter the status code accordingly. // Our deadline was already exceeded, and that was likely the cause
if d, ok := s.ctx.Deadline(); ok && d.After(time.Now()) { // of this cancelation. Alter the status code accordingly.
statusCode = codes.DeadlineExceeded statusCode = codes.DeadlineExceeded
} }
} }
@ -1080,11 +1090,12 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
default: default:
t.setGoAwayReason(f) t.setGoAwayReason(f)
close(t.goAway) close(t.goAway)
t.state = draining
t.controlBuf.put(&incomingGoAway{}) t.controlBuf.put(&incomingGoAway{})
// Notify the clientconn about the GOAWAY before we set the state to
// This has to be a new goroutine because we're still using the current goroutine to read in the transport. // draining, to allow the client to stop attempting to create streams
// before disallowing new streams on this connection.
t.onGoAway(t.goAwayReason) t.onGoAway(t.goAwayReason)
t.state = draining
} }
// All streams with IDs greater than the GoAwayId // All streams with IDs greater than the GoAwayId
// and smaller than the previous GoAway ID should be killed. // and smaller than the previous GoAway ID should be killed.
@ -1142,26 +1153,24 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
endStream := frame.StreamEnded() endStream := frame.StreamEnded()
atomic.StoreUint32(&s.bytesReceived, 1) atomic.StoreUint32(&s.bytesReceived, 1)
initialHeader := atomic.SwapUint32(&s.headerDone, 1) == 0 initialHeader := atomic.LoadUint32(&s.headerChanClosed) == 0
if !initialHeader && !endStream { if !initialHeader && !endStream {
// As specified by RFC 7540, a HEADERS frame (and associated CONTINUATION frames) can only appear // As specified by gRPC over HTTP2, a HEADERS frame (and associated CONTINUATION frames) can only appear at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
// at the start or end of a stream. Therefore, second HEADERS frame must have EOS bit set.
st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream") st := status.New(codes.Internal, "a HEADERS frame cannot appear in the middle of a stream")
t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false) t.closeStream(s, st.Err(), true, http2.ErrCodeProtocol, st, nil, false)
return return
} }
state := &decodeState{} state := &decodeState{}
// Initialize isGRPC value to be !initialHeader, since if a gRPC ResponseHeader has been received // Initialize isGRPC value to be !initialHeader, since if a gRPC Response-Headers has already been received, then it means that the peer is speaking gRPC and we are in gRPC mode.
// which indicates peer speaking gRPC, we are in gRPC mode.
state.data.isGRPC = !initialHeader state.data.isGRPC = !initialHeader
if err := state.decodeHeader(frame); err != nil { if err := state.decodeHeader(frame); err != nil {
t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream) t.closeStream(s, err, true, http2.ErrCodeProtocol, status.Convert(err), nil, endStream)
return return
} }
var isHeader bool isHeader := false
defer func() { defer func() {
if t.statsHandler != nil { if t.statsHandler != nil {
if isHeader { if isHeader {
@ -1180,10 +1189,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
} }
}() }()
// If headers haven't been received yet. // If headerChan hasn't been closed yet
if initialHeader { if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
if !endStream { if !endStream {
// Headers frame is ResponseHeader. // HEADERS frame block carries a Response-Headers.
isHeader = true isHeader = true
// These values can be set without any synchronization because // These values can be set without any synchronization because
// stream goroutine will read it only after seeing a closed // stream goroutine will read it only after seeing a closed
@ -1192,14 +1201,17 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
if len(state.data.mdata) > 0 { if len(state.data.mdata) > 0 {
s.header = state.data.mdata s.header = state.data.mdata
} }
close(s.headerChan) } else {
return // HEADERS frame block carries a Trailers-Only.
s.noHeaders = true
} }
// Headers frame is Trailers-only.
s.noHeaders = true
close(s.headerChan) close(s.headerChan)
} }
if !endStream {
return
}
// if client received END_STREAM from server while stream was still active, send RST_STREAM // if client received END_STREAM from server while stream was still active, send RST_STREAM
rst := s.getState() == streamActive rst := s.getState() == streamActive
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true) t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, state.status(), state.data.mdata, true)
@ -1233,6 +1245,7 @@ func (t *http2Client) reader() {
// loop to keep reading incoming messages on this transport. // loop to keep reading incoming messages on this transport.
for { for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
if t.keepaliveEnabled { if t.keepaliveEnabled {
atomic.CompareAndSwapUint32(&t.activity, 0, 1) atomic.CompareAndSwapUint32(&t.activity, 0, 1)
@ -1320,6 +1333,7 @@ func (t *http2Client) keepalive() {
timer.Reset(t.kp.Time) timer.Reset(t.kp.Time)
continue continue
} }
infof("transport: closing client transport due to idleness.")
t.Close() t.Close()
return return
case <-t.ctx.Done(): case <-t.ctx.Done():

View file

@ -35,9 +35,11 @@ import (
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@ -55,6 +57,9 @@ var (
// ErrHeaderListSizeLimitViolation indicates that the header list size is larger // ErrHeaderListSizeLimitViolation indicates that the header list size is larger
// than the limit set by peer. // than the limit set by peer.
ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer") ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer")
// statusRawProto is a function to get to the raw status proto wrapped in a
// status.Status without a proto.Clone().
statusRawProto = internal.StatusRawProto.(func(*status.Status) *spb.Status)
) )
// http2Server implements the ServerTransport interface with HTTP2. // http2Server implements the ServerTransport interface with HTTP2.
@ -119,6 +124,7 @@ type http2Server struct {
// Fields below are for channelz metric collection. // Fields below are for channelz metric collection.
channelzID int64 // channelz unique identification number channelzID int64 // channelz unique identification number
czData *channelzData czData *channelzData
bufferPool *bufferPool
} }
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is // newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
@ -220,6 +226,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err
kep: kep, kep: kep,
initialWindowSize: iwz, initialWindowSize: iwz,
czData: new(channelzData), czData: new(channelzData),
bufferPool: newBufferPool(),
} }
t.controlBuf = newControlBuffer(t.ctxDone) t.controlBuf = newControlBuffer(t.ctxDone)
if dynamicWindow { if dynamicWindow {
@ -405,9 +412,10 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ reader: &recvBufferReader{
ctx: s.ctx, ctx: s.ctx,
ctxDone: s.ctxDone, ctxDone: s.ctxDone,
recv: s.buf, recv: s.buf,
freeBuffer: t.bufferPool.put,
}, },
windowHandler: func(n int) { windowHandler: func(n int) {
t.updateWindow(s, uint32(n)) t.updateWindow(s, uint32(n))
@ -428,6 +436,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
defer close(t.readerDone) defer close(t.readerDone)
for { for {
t.controlBuf.throttle()
frame, err := t.framer.fr.ReadFrame() frame, err := t.framer.fr.ReadFrame()
atomic.StoreUint32(&t.activity, 1) atomic.StoreUint32(&t.activity, 1)
if err != nil { if err != nil {
@ -591,9 +600,10 @@ func (t *http2Server) handleData(f *http2.DataFrame) {
// guarantee f.Data() is consumed before the arrival of next frame. // guarantee f.Data() is consumed before the arrival of next frame.
// Can this copy be eliminated? // Can this copy be eliminated?
if len(f.Data()) > 0 { if len(f.Data()) > 0 {
data := make([]byte, len(f.Data())) buffer := t.bufferPool.get()
copy(data, f.Data()) buffer.Reset()
s.write(recvMsg{data: data}) buffer.Write(f.Data())
s.write(recvMsg{buffer: buffer})
} }
} }
if f.Header().Flags.Has(http2.FlagDataEndStream) { if f.Header().Flags.Has(http2.FlagDataEndStream) {
@ -757,6 +767,10 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
return nil return nil
} }
func (t *http2Server) setResetPingStrikes() {
atomic.StoreUint32(&t.resetPingStrikes, 1)
}
func (t *http2Server) writeHeaderLocked(s *Stream) error { func (t *http2Server) writeHeaderLocked(s *Stream) error {
// TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields // TODO(mmukhi): Benchmark if the performance gets better if count the metadata and other header fields
// first and create a slice of that exact size. // first and create a slice of that exact size.
@ -771,9 +785,7 @@ func (t *http2Server) writeHeaderLocked(s *Stream) error {
streamID: s.id, streamID: s.id,
hf: headerFields, hf: headerFields,
endStream: false, endStream: false,
onWrite: func() { onWrite: t.setResetPingStrikes,
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
}) })
if !success { if !success {
if err != nil { if err != nil {
@ -817,7 +829,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})
if p := st.Proto(); p != nil && len(p.Details) > 0 { if p := statusRawProto(st); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p) stBytes, err := proto.Marshal(p)
if err != nil { if err != nil {
// TODO: return error instead, when callers are able to handle it. // TODO: return error instead, when callers are able to handle it.
@ -833,9 +845,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
streamID: s.id, streamID: s.id,
hf: headerFields, hf: headerFields,
endStream: true, endStream: true,
onWrite: func() { onWrite: t.setResetPingStrikes,
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
} }
s.hdrMu.Unlock() s.hdrMu.Unlock()
success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader) success, err := t.controlBuf.execute(t.checkForHeaderListSize, trailingHeader)
@ -887,12 +897,10 @@ func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) e
hdr = append(hdr, data[:emptyLen]...) hdr = append(hdr, data[:emptyLen]...)
data = data[emptyLen:] data = data[emptyLen:]
df := &dataFrame{ df := &dataFrame{
streamID: s.id, streamID: s.id,
h: hdr, h: hdr,
d: data, d: data,
onEachWrite: func() { onEachWrite: t.setResetPingStrikes,
atomic.StoreUint32(&t.resetPingStrikes, 1)
},
} }
if err := s.wq.get(int32(len(hdr) + len(data))); err != nil { if err := s.wq.get(int32(len(hdr) + len(data))); err != nil {
select { select {
@ -958,6 +966,7 @@ func (t *http2Server) keepalive() {
select { select {
case <-maxAge.C: case <-maxAge.C:
// Close the connection after grace period. // Close the connection after grace period.
infof("transport: closing server transport due to maximum connection age.")
t.Close() t.Close()
// Resetting the timer so that the clean-up doesn't deadlock. // Resetting the timer so that the clean-up doesn't deadlock.
maxAge.Reset(infinity) maxAge.Reset(infinity)
@ -971,6 +980,7 @@ func (t *http2Server) keepalive() {
continue continue
} }
if pingSent { if pingSent {
infof("transport: closing server transport due to idleness.")
t.Close() t.Close()
// Resetting the timer so that the clean-up doesn't deadlock. // Resetting the timer so that the clean-up doesn't deadlock.
keepalive.Reset(infinity) keepalive.Reset(infinity)
@ -1019,13 +1029,7 @@ func (t *http2Server) Close() error {
} }
// deleteStream deletes the stream s from transport's active streams. // deleteStream deletes the stream s from transport's active streams.
func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState streamState) { func (t *http2Server) deleteStream(s *Stream, eosReceived bool) {
oldState = s.swapState(streamDone)
if oldState == streamDone {
// If the stream was already done, return.
return oldState
}
// In case stream sending and receiving are invoked in separate // In case stream sending and receiving are invoked in separate
// goroutines (e.g., bi-directional streaming), cancel needs to be // goroutines (e.g., bi-directional streaming), cancel needs to be
// called to interrupt the potential blocking on other goroutines. // called to interrupt the potential blocking on other goroutines.
@ -1047,15 +1051,13 @@ func (t *http2Server) deleteStream(s *Stream, eosReceived bool) (oldState stream
atomic.AddInt64(&t.czData.streamsFailed, 1) atomic.AddInt64(&t.czData.streamsFailed, 1)
} }
} }
return oldState
} }
// finishStream closes the stream and puts the trailing headerFrame into controlbuf. // finishStream closes the stream and puts the trailing headerFrame into controlbuf.
func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) { func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, hdr *headerFrame, eosReceived bool) {
oldState := t.deleteStream(s, eosReceived) oldState := s.swapState(streamDone)
// If the stream is already closed, then don't put trailing header to controlbuf.
if oldState == streamDone { if oldState == streamDone {
// If the stream was already done, return.
return return
} }
@ -1063,14 +1065,18 @@ func (t *http2Server) finishStream(s *Stream, rst bool, rstCode http2.ErrCode, h
streamID: s.id, streamID: s.id,
rst: rst, rst: rst,
rstCode: rstCode, rstCode: rstCode,
onWrite: func() {}, onWrite: func() {
t.deleteStream(s, eosReceived)
},
} }
t.controlBuf.put(hdr) t.controlBuf.put(hdr)
} }
// closeStream clears the footprint of a stream when the stream is not needed any more. // closeStream clears the footprint of a stream when the stream is not needed any more.
func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) { func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eosReceived bool) {
s.swapState(streamDone)
t.deleteStream(s, eosReceived) t.deleteStream(s, eosReceived)
t.controlBuf.put(&cleanupStream{ t.controlBuf.put(&cleanupStream{
streamID: s.id, streamID: s.id,
rst: rst, rst: rst,

View file

@ -22,6 +22,7 @@
package transport package transport
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@ -39,10 +40,32 @@ import (
"google.golang.org/grpc/tap" "google.golang.org/grpc/tap"
) )
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
},
}
}
func (p *bufferPool) get() *bytes.Buffer {
return p.pool.Get().(*bytes.Buffer)
}
func (p *bufferPool) put(b *bytes.Buffer) {
p.pool.Put(b)
}
// recvMsg represents the received msg from the transport. All transport // recvMsg represents the received msg from the transport. All transport
// protocol specific info has been removed. // protocol specific info has been removed.
type recvMsg struct { type recvMsg struct {
data []byte buffer *bytes.Buffer
// nil: received some data // nil: received some data
// io.EOF: stream is completed. data is nil. // io.EOF: stream is completed. data is nil.
// other non-nil error: transport failure. data is nil. // other non-nil error: transport failure. data is nil.
@ -117,8 +140,9 @@ type recvBufferReader struct {
ctx context.Context ctx context.Context
ctxDone <-chan struct{} // cache of ctx.Done() (for performance). ctxDone <-chan struct{} // cache of ctx.Done() (for performance).
recv *recvBuffer recv *recvBuffer
last []byte // Stores the remaining data in the previous calls. last *bytes.Buffer // Stores the remaining data in the previous calls.
err error err error
freeBuffer func(*bytes.Buffer)
} }
// Read reads the next len(p) bytes from last. If last is drained, it tries to // Read reads the next len(p) bytes from last. If last is drained, it tries to
@ -128,10 +152,13 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
if r.err != nil { if r.err != nil {
return 0, r.err return 0, r.err
} }
if r.last != nil && len(r.last) > 0 { if r.last != nil {
// Read remaining data left in last call. // Read remaining data left in last call.
copied := copy(p, r.last) copied, _ := r.last.Read(p)
r.last = r.last[copied:] if r.last.Len() == 0 {
r.freeBuffer(r.last)
r.last = nil
}
return copied, nil return copied, nil
} }
if r.closeStream != nil { if r.closeStream != nil {
@ -157,6 +184,19 @@ func (r *recvBufferReader) readClient(p []byte) (n int, err error) {
// r.readAdditional acts on that message and returns the necessary error. // r.readAdditional acts on that message and returns the necessary error.
select { select {
case <-r.ctxDone: case <-r.ctxDone:
// Note that this adds the ctx error to the end of recv buffer, and
// reads from the head. This will delay the error until recv buffer is
// empty, thus will delay ctx cancellation in Recv().
//
// It's done this way to fix a race between ctx cancel and trailer. The
// race was, stream.Recv() may return ctx error if ctxDone wins the
// race, but stream.Trailer() may return a non-nil md because the stream
// was not marked as done when trailer is received. This closeStream
// call will mark stream as done, thus fix the race.
//
// TODO: delaying ctx error seems like a unnecessary side effect. What
// we really want is to mark the stream as done, and return ctx error
// faster.
r.closeStream(ContextErr(r.ctx.Err())) r.closeStream(ContextErr(r.ctx.Err()))
m := <-r.recv.get() m := <-r.recv.get()
return r.readAdditional(m, p) return r.readAdditional(m, p)
@ -170,8 +210,13 @@ func (r *recvBufferReader) readAdditional(m recvMsg, p []byte) (n int, err error
if m.err != nil { if m.err != nil {
return 0, m.err return 0, m.err
} }
copied := copy(p, m.data) copied, _ := m.buffer.Read(p)
r.last = m.data[copied:] if m.buffer.Len() == 0 {
r.freeBuffer(m.buffer)
r.last = nil
} else {
r.last = m.buffer
}
return copied, nil return copied, nil
} }
@ -204,8 +249,8 @@ type Stream struct {
// is used to adjust flow control, if needed. // is used to adjust flow control, if needed.
requestRead func(int) requestRead func(int)
headerChan chan struct{} // closed to indicate the end of header metadata. headerChan chan struct{} // closed to indicate the end of header metadata.
headerDone uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times.
// hdrMu protects header and trailer metadata on the server-side. // hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex hdrMu sync.Mutex
@ -266,6 +311,14 @@ func (s *Stream) waitOnHeader() error {
} }
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
// We prefer success over failure when reading messages because we delay
// context error in stream.Read(). To keep behavior consistent, we also
// prefer success here.
select {
case <-s.headerChan:
return nil
default:
}
return ContextErr(s.ctx.Err()) return ContextErr(s.ctx.Err())
case <-s.headerChan: case <-s.headerChan:
return nil return nil
@ -578,9 +631,12 @@ type ClientTransport interface {
// is called only once. // is called only once.
Close() error Close() error
// GracefulClose starts to tear down the transport. It stops accepting // GracefulClose starts to tear down the transport: the transport will stop
// new RPCs and wait the completion of the pending RPCs. // accepting new RPCs and NewStream will return error. Once all streams are
GracefulClose() error // finished, the transport will close.
//
// It does not block.
GracefulClose()
// Write sends the data for the given stream. A nil stream indicates // Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole. // the write is to be performed on the transport as a whole.

View file

@ -17,9 +17,8 @@
*/ */
// Package naming defines the naming API and related data structures for gRPC. // Package naming defines the naming API and related data structures for gRPC.
// The interface is EXPERIMENTAL and may be subject to change.
// //
// Deprecated: please use package resolver. // This package is deprecated: please use package resolver instead.
package naming package naming
// Operation defines the corresponding operations for a name resolution change. // Operation defines the corresponding operations for a name resolution change.

View file

@ -120,6 +120,14 @@ func (bp *pickerWrapper) pick(ctx context.Context, failfast bool, opts balancer.
bp.mu.Unlock() bp.mu.Unlock()
select { select {
case <-ctx.Done(): case <-ctx.Done():
if connectionErr := bp.connectionError(); connectionErr != nil {
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, nil, status.Errorf(codes.DeadlineExceeded, "latest connection error: %v", connectionErr)
case context.Canceled:
return nil, nil, status.Errorf(codes.Canceled, "latest connection error: %v", connectionErr)
}
}
return nil, nil, ctx.Err() return nil, nil, ctx.Err()
case <-ch: case <-ch:
} }

View file

@ -51,14 +51,18 @@ type pickfirstBalancer struct {
func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) { func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err error) {
if err != nil { if err != nil {
grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err) if grpclog.V(2) {
grpclog.Infof("pickfirstBalancer: HandleResolvedAddrs called with error %v", err)
}
return return
} }
if b.sc == nil { if b.sc == nil {
b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{}) b.sc, err = b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{})
if err != nil { if err != nil {
//TODO(yuxuanli): why not change the cc state to Idle? //TODO(yuxuanli): why not change the cc state to Idle?
grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err) if grpclog.V(2) {
grpclog.Errorf("pickfirstBalancer: failed to NewSubConn: %v", err)
}
return return
} }
b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc}) b.cc.UpdateBalancerState(connectivity.Idle, &picker{sc: b.sc})
@ -70,9 +74,13 @@ func (b *pickfirstBalancer) HandleResolvedAddrs(addrs []resolver.Address, err er
} }
func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) { func (b *pickfirstBalancer) HandleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s) if grpclog.V(2) {
grpclog.Infof("pickfirstBalancer: HandleSubConnStateChange: %p, %v", sc, s)
}
if b.sc != sc { if b.sc != sc {
grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized") if grpclog.V(2) {
grpclog.Infof("pickfirstBalancer: ignored state change because sc is not recognized")
}
return return
} }
if s == connectivity.Shutdown { if s == connectivity.Shutdown {

64
vendor/google.golang.org/grpc/preloader.go generated vendored Normal file
View file

@ -0,0 +1,64 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpc
import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// PreparedMsg is responsible for creating a Marshalled and Compressed object.
//
// This API is EXPERIMENTAL.
type PreparedMsg struct {
// Struct for preparing msg before sending them
encodedData []byte
hdr []byte
payload []byte
}
// Encode marshalls and compresses the message using the codec and compressor for the stream.
func (p *PreparedMsg) Encode(s Stream, msg interface{}) error {
ctx := s.Context()
rpcInfo, ok := rpcInfoFromContext(ctx)
if !ok {
return status.Errorf(codes.Internal, "grpc: unable to get rpcInfo")
}
// check if the context has the relevant information to prepareMsg
if rpcInfo.preloaderInfo == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil")
}
if rpcInfo.preloaderInfo.codec == nil {
return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil")
}
// prepare the msg
data, err := encode(rpcInfo.preloaderInfo.codec, msg)
if err != nil {
return err
}
p.encodedData = data
compData, err := compress(data, rpcInfo.preloaderInfo.cp, rpcInfo.preloaderInfo.comp)
if err != nil {
return err
}
p.hdr, p.payload = msgHeader(data, compData)
return nil
}

View file

@ -66,6 +66,9 @@ var (
var ( var (
defaultResolver netResolver = net.DefaultResolver defaultResolver netResolver = net.DefaultResolver
// To prevent excessive re-resolution, we enforce a rate limit on DNS
// resolution requests.
minDNSResRate = 30 * time.Second
) )
var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) { var customAuthorityDialler = func(authority string) func(ctx context.Context, network, address string) (net.Conn, error) {
@ -241,7 +244,13 @@ func (d *dnsResolver) watcher() {
return return
case <-d.t.C: case <-d.t.C:
case <-d.rn: case <-d.rn:
if !d.t.Stop() {
// Before resetting a timer, it should be stopped to prevent racing with
// reads on it's channel.
<-d.t.C
}
} }
result, sc := d.lookup() result, sc := d.lookup()
// Next lookup should happen within an interval defined by d.freq. It may be // Next lookup should happen within an interval defined by d.freq. It may be
// more often due to exponential retry on empty address list. // more often due to exponential retry on empty address list.
@ -254,6 +263,16 @@ func (d *dnsResolver) watcher() {
} }
d.cc.NewServiceConfig(sc) d.cc.NewServiceConfig(sc)
d.cc.NewAddress(result) d.cc.NewAddress(result)
// Sleep to prevent excessive re-resolutions. Incoming resolution requests
// will be queued in d.rn.
t := time.NewTimer(minDNSResRate)
select {
case <-t.C:
case <-d.ctx.Done():
t.Stop()
return
}
} }
} }

View file

@ -20,6 +20,10 @@
// All APIs in this package are experimental. // All APIs in this package are experimental.
package resolver package resolver
import (
"google.golang.org/grpc/serviceconfig"
)
var ( var (
// m is a map from scheme to resolver builder. // m is a map from scheme to resolver builder.
m = make(map[string]Builder) m = make(map[string]Builder)
@ -100,11 +104,12 @@ type BuildOption struct {
// State contains the current Resolver state relevant to the ClientConn. // State contains the current Resolver state relevant to the ClientConn.
type State struct { type State struct {
Addresses []Address // Resolved addresses for the target Addresses []Address // Resolved addresses for the target
ServiceConfig string // JSON representation of the service config // ServiceConfig is the parsed service config; obtained from
// serviceconfig.Parse.
ServiceConfig serviceconfig.Config
// TODO: add Err error // TODO: add Err error
// TODO: add ParsedServiceConfig interface{}
} }
// ClientConn contains the callbacks for resolver to notify any updates // ClientConn contains the callbacks for resolver to notify any updates
@ -132,6 +137,21 @@ type ClientConn interface {
// Target represents a target for gRPC, as specified in: // Target represents a target for gRPC, as specified in:
// https://github.com/grpc/grpc/blob/master/doc/naming.md. // https://github.com/grpc/grpc/blob/master/doc/naming.md.
// It is parsed from the target string that gets passed into Dial or DialContext by the user. And
// grpc passes it to the resolver and the balancer.
//
// If the target follows the naming spec, and the parsed scheme is registered with grpc, we will
// parse the target string according to the spec. e.g. "dns://some_authority/foo.bar" will be parsed
// into &Target{Scheme: "dns", Authority: "some_authority", Endpoint: "foo.bar"}
//
// If the target does not contain a scheme, we will apply the default scheme, and set the Target to
// be the full target string. e.g. "foo.bar" will be parsed into
// &Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "foo.bar"}.
//
// If the parsed scheme is not registered (i.e. no corresponding resolver available to resolve the
// endpoint), we set the Scheme to be the default scheme, and set the Endpoint to be the full target
// string. e.g. target string "unknown_scheme://authority/endpoint" will be parsed into
// &Target{Scheme: resolver.GetDefaultScheme(), Endpoint: "unknown_scheme://authority/endpoint"}.
type Target struct { type Target struct {
Scheme string Scheme string
Authority string Authority string

View file

@ -138,19 +138,22 @@ func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
return return
} }
grpclog.Infof("ccResolverWrapper: got new service config: %v", sc) grpclog.Infof("ccResolverWrapper: got new service config: %v", sc)
if channelz.IsOn() { c, err := parseServiceConfig(sc)
ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: sc}) if err != nil {
return
} }
ccr.curState.ServiceConfig = sc if channelz.IsOn() {
ccr.addChannelzTraceEvent(resolver.State{Addresses: ccr.curState.Addresses, ServiceConfig: c})
}
ccr.curState.ServiceConfig = c
ccr.cc.updateResolverState(ccr.curState) ccr.cc.updateResolverState(ccr.curState)
} }
func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) { func (ccr *ccResolverWrapper) addChannelzTraceEvent(s resolver.State) {
if s.ServiceConfig == ccr.curState.ServiceConfig && (len(ccr.curState.Addresses) == 0) == (len(s.Addresses) == 0) {
return
}
var updates []string var updates []string
if s.ServiceConfig != ccr.curState.ServiceConfig { oldSC, oldOK := ccr.curState.ServiceConfig.(*ServiceConfig)
newSC, newOK := s.ServiceConfig.(*ServiceConfig)
if oldOK != newOK || (oldOK && newOK && oldSC.rawJSONString != newSC.rawJSONString) {
updates = append(updates, "service config updated") updates = append(updates, "service config updated")
} }
if len(ccr.curState.Addresses) > 0 && len(s.Addresses) == 0 { if len(ccr.curState.Addresses) > 0 && len(s.Addresses) == 0 {

View file

@ -694,14 +694,34 @@ func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interf
return nil return nil
} }
// Information about RPC
type rpcInfo struct { type rpcInfo struct {
failfast bool failfast bool
preloaderInfo *compressorInfo
}
// Information about Preloader
// Responsible for storing codec, and compressors
// If stream (s) has context s.Context which stores rpcInfo that has non nil
// pointers to codec, and compressors, then we can use preparedMsg for Async message prep
// and reuse marshalled bytes
type compressorInfo struct {
codec baseCodec
cp Compressor
comp encoding.Compressor
} }
type rpcInfoContextKey struct{} type rpcInfoContextKey struct{}
func newContextWithRPCInfo(ctx context.Context, failfast bool) context.Context { func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context {
return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{failfast: failfast}) return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{
failfast: failfast,
preloaderInfo: &compressorInfo{
codec: codec,
cp: cp,
comp: comp,
},
})
} }
func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) { func rpcInfoFromContext(ctx context.Context) (s *rpcInfo, ok bool) {

View file

@ -42,6 +42,7 @@ import (
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@ -56,6 +57,8 @@ const (
defaultServerMaxSendMessageSize = math.MaxInt32 defaultServerMaxSendMessageSize = math.MaxInt32
) )
var statusOK = status.New(codes.OK, "")
type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error) type methodHandler func(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor UnaryServerInterceptor) (interface{}, error)
// MethodDesc represents an RPC service's method specification. // MethodDesc represents an RPC service's method specification.
@ -86,21 +89,19 @@ type service struct {
// Server is a gRPC server to serve RPC requests. // Server is a gRPC server to serve RPC requests.
type Server struct { type Server struct {
opts options opts serverOptions
mu sync.Mutex // guards following mu sync.Mutex // guards following
lis map[net.Listener]bool lis map[net.Listener]bool
conns map[io.Closer]bool conns map[transport.ServerTransport]bool
serve bool serve bool
drain bool drain bool
cv *sync.Cond // signaled when connections close for GracefulStop cv *sync.Cond // signaled when connections close for GracefulStop
m map[string]*service // service name -> service info m map[string]*service // service name -> service info
events trace.EventLog events trace.EventLog
quit chan struct{} quit *grpcsync.Event
done chan struct{} done *grpcsync.Event
quitOnce sync.Once
doneOnce sync.Once
channelzRemoveOnce sync.Once channelzRemoveOnce sync.Once
serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop serveWG sync.WaitGroup // counts active Serve goroutines for GracefulStop
@ -108,7 +109,7 @@ type Server struct {
czData *channelzData czData *channelzData
} }
type options struct { type serverOptions struct {
creds credentials.TransportCredentials creds credentials.TransportCredentials
codec baseCodec codec baseCodec
cp Compressor cp Compressor
@ -131,7 +132,7 @@ type options struct {
maxHeaderListSize *uint32 maxHeaderListSize *uint32
} }
var defaultServerOptions = options{ var defaultServerOptions = serverOptions{
maxReceiveMessageSize: defaultServerMaxReceiveMessageSize, maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
maxSendMessageSize: defaultServerMaxSendMessageSize, maxSendMessageSize: defaultServerMaxSendMessageSize,
connectionTimeout: 120 * time.Second, connectionTimeout: 120 * time.Second,
@ -140,7 +141,33 @@ var defaultServerOptions = options{
} }
// A ServerOption sets options such as credentials, codec and keepalive parameters, etc. // A ServerOption sets options such as credentials, codec and keepalive parameters, etc.
type ServerOption func(*options) type ServerOption interface {
apply(*serverOptions)
}
// EmptyServerOption does not alter the server configuration. It can be embedded
// in another structure to build custom server options.
//
// This API is EXPERIMENTAL.
type EmptyServerOption struct{}
func (EmptyServerOption) apply(*serverOptions) {}
// funcServerOption wraps a function that modifies serverOptions into an
// implementation of the ServerOption interface.
type funcServerOption struct {
f func(*serverOptions)
}
func (fdo *funcServerOption) apply(do *serverOptions) {
fdo.f(do)
}
func newFuncServerOption(f func(*serverOptions)) *funcServerOption {
return &funcServerOption{
f: f,
}
}
// WriteBufferSize determines how much data can be batched before doing a write on the wire. // WriteBufferSize determines how much data can be batched before doing a write on the wire.
// The corresponding memory allocation for this buffer will be twice the size to keep syscalls low. // The corresponding memory allocation for this buffer will be twice the size to keep syscalls low.
@ -148,9 +175,9 @@ type ServerOption func(*options)
// Zero will disable the write buffer such that each write will be on underlying connection. // Zero will disable the write buffer such that each write will be on underlying connection.
// Note: A Send call may not directly translate to a write. // Note: A Send call may not directly translate to a write.
func WriteBufferSize(s int) ServerOption { func WriteBufferSize(s int) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.writeBufferSize = s o.writeBufferSize = s
} })
} }
// ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most // ReadBufferSize lets you set the size of read buffer, this determines how much data can be read at most
@ -159,25 +186,25 @@ func WriteBufferSize(s int) ServerOption {
// Zero will disable read buffer for a connection so data framer can access the underlying // Zero will disable read buffer for a connection so data framer can access the underlying
// conn directly. // conn directly.
func ReadBufferSize(s int) ServerOption { func ReadBufferSize(s int) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.readBufferSize = s o.readBufferSize = s
} })
} }
// InitialWindowSize returns a ServerOption that sets window size for stream. // InitialWindowSize returns a ServerOption that sets window size for stream.
// The lower bound for window size is 64K and any value smaller than that will be ignored. // The lower bound for window size is 64K and any value smaller than that will be ignored.
func InitialWindowSize(s int32) ServerOption { func InitialWindowSize(s int32) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.initialWindowSize = s o.initialWindowSize = s
} })
} }
// InitialConnWindowSize returns a ServerOption that sets window size for a connection. // InitialConnWindowSize returns a ServerOption that sets window size for a connection.
// The lower bound for window size is 64K and any value smaller than that will be ignored. // The lower bound for window size is 64K and any value smaller than that will be ignored.
func InitialConnWindowSize(s int32) ServerOption { func InitialConnWindowSize(s int32) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.initialConnWindowSize = s o.initialConnWindowSize = s
} })
} }
// KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server. // KeepaliveParams returns a ServerOption that sets keepalive and max-age parameters for the server.
@ -187,25 +214,25 @@ func KeepaliveParams(kp keepalive.ServerParameters) ServerOption {
kp.Time = time.Second kp.Time = time.Second
} }
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.keepaliveParams = kp o.keepaliveParams = kp
} })
} }
// KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server. // KeepaliveEnforcementPolicy returns a ServerOption that sets keepalive enforcement policy for the server.
func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption { func KeepaliveEnforcementPolicy(kep keepalive.EnforcementPolicy) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.keepalivePolicy = kep o.keepalivePolicy = kep
} })
} }
// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. // CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling.
// //
// This will override any lookups by content-subtype for Codecs registered with RegisterCodec. // This will override any lookups by content-subtype for Codecs registered with RegisterCodec.
func CustomCodec(codec Codec) ServerOption { func CustomCodec(codec Codec) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.codec = codec o.codec = codec
} })
} }
// RPCCompressor returns a ServerOption that sets a compressor for outbound // RPCCompressor returns a ServerOption that sets a compressor for outbound
@ -216,9 +243,9 @@ func CustomCodec(codec Codec) ServerOption {
// //
// Deprecated: use encoding.RegisterCompressor instead. // Deprecated: use encoding.RegisterCompressor instead.
func RPCCompressor(cp Compressor) ServerOption { func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.cp = cp o.cp = cp
} })
} }
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound // RPCDecompressor returns a ServerOption that sets a decompressor for inbound
@ -227,9 +254,9 @@ func RPCCompressor(cp Compressor) ServerOption {
// //
// Deprecated: use encoding.RegisterCompressor instead. // Deprecated: use encoding.RegisterCompressor instead.
func RPCDecompressor(dc Decompressor) ServerOption { func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.dc = dc o.dc = dc
} })
} }
// MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive. // MaxMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
@ -243,73 +270,73 @@ func MaxMsgSize(m int) ServerOption {
// MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive. // MaxRecvMsgSize returns a ServerOption to set the max message size in bytes the server can receive.
// If this is not set, gRPC uses the default 4MB. // If this is not set, gRPC uses the default 4MB.
func MaxRecvMsgSize(m int) ServerOption { func MaxRecvMsgSize(m int) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.maxReceiveMessageSize = m o.maxReceiveMessageSize = m
} })
} }
// MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send. // MaxSendMsgSize returns a ServerOption to set the max message size in bytes the server can send.
// If this is not set, gRPC uses the default `math.MaxInt32`. // If this is not set, gRPC uses the default `math.MaxInt32`.
func MaxSendMsgSize(m int) ServerOption { func MaxSendMsgSize(m int) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.maxSendMessageSize = m o.maxSendMessageSize = m
} })
} }
// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport. // of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption { func MaxConcurrentStreams(n uint32) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.maxConcurrentStreams = n o.maxConcurrentStreams = n
} })
} }
// Creds returns a ServerOption that sets credentials for server connections. // Creds returns a ServerOption that sets credentials for server connections.
func Creds(c credentials.TransportCredentials) ServerOption { func Creds(c credentials.TransportCredentials) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.creds = c o.creds = c
} })
} }
// UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the // UnaryInterceptor returns a ServerOption that sets the UnaryServerInterceptor for the
// server. Only one unary interceptor can be installed. The construction of multiple // server. Only one unary interceptor can be installed. The construction of multiple
// interceptors (e.g., chaining) can be implemented at the caller. // interceptors (e.g., chaining) can be implemented at the caller.
func UnaryInterceptor(i UnaryServerInterceptor) ServerOption { func UnaryInterceptor(i UnaryServerInterceptor) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
if o.unaryInt != nil { if o.unaryInt != nil {
panic("The unary server interceptor was already set and may not be reset.") panic("The unary server interceptor was already set and may not be reset.")
} }
o.unaryInt = i o.unaryInt = i
} })
} }
// StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the // StreamInterceptor returns a ServerOption that sets the StreamServerInterceptor for the
// server. Only one stream interceptor can be installed. // server. Only one stream interceptor can be installed.
func StreamInterceptor(i StreamServerInterceptor) ServerOption { func StreamInterceptor(i StreamServerInterceptor) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
if o.streamInt != nil { if o.streamInt != nil {
panic("The stream server interceptor was already set and may not be reset.") panic("The stream server interceptor was already set and may not be reset.")
} }
o.streamInt = i o.streamInt = i
} })
} }
// InTapHandle returns a ServerOption that sets the tap handle for all the server // InTapHandle returns a ServerOption that sets the tap handle for all the server
// transport to be created. Only one can be installed. // transport to be created. Only one can be installed.
func InTapHandle(h tap.ServerInHandle) ServerOption { func InTapHandle(h tap.ServerInHandle) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
if o.inTapHandle != nil { if o.inTapHandle != nil {
panic("The tap handle was already set and may not be reset.") panic("The tap handle was already set and may not be reset.")
} }
o.inTapHandle = h o.inTapHandle = h
} })
} }
// StatsHandler returns a ServerOption that sets the stats handler for the server. // StatsHandler returns a ServerOption that sets the stats handler for the server.
func StatsHandler(h stats.Handler) ServerOption { func StatsHandler(h stats.Handler) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.statsHandler = h o.statsHandler = h
} })
} }
// UnknownServiceHandler returns a ServerOption that allows for adding a custom // UnknownServiceHandler returns a ServerOption that allows for adding a custom
@ -319,7 +346,7 @@ func StatsHandler(h stats.Handler) ServerOption {
// The handling function has full access to the Context of the request and the // The handling function has full access to the Context of the request and the
// stream, and the invocation bypasses interceptors. // stream, and the invocation bypasses interceptors.
func UnknownServiceHandler(streamHandler StreamHandler) ServerOption { func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.unknownStreamDesc = &StreamDesc{ o.unknownStreamDesc = &StreamDesc{
StreamName: "unknown_service_handler", StreamName: "unknown_service_handler",
Handler: streamHandler, Handler: streamHandler,
@ -327,7 +354,7 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
ClientStreams: true, ClientStreams: true,
ServerStreams: true, ServerStreams: true,
} }
} })
} }
// ConnectionTimeout returns a ServerOption that sets the timeout for // ConnectionTimeout returns a ServerOption that sets the timeout for
@ -337,17 +364,17 @@ func UnknownServiceHandler(streamHandler StreamHandler) ServerOption {
// //
// This API is EXPERIMENTAL. // This API is EXPERIMENTAL.
func ConnectionTimeout(d time.Duration) ServerOption { func ConnectionTimeout(d time.Duration) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.connectionTimeout = d o.connectionTimeout = d
} })
} }
// MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size // MaxHeaderListSize returns a ServerOption that sets the max (uncompressed) size
// of header list that the server is prepared to accept. // of header list that the server is prepared to accept.
func MaxHeaderListSize(s uint32) ServerOption { func MaxHeaderListSize(s uint32) ServerOption {
return func(o *options) { return newFuncServerOption(func(o *serverOptions) {
o.maxHeaderListSize = &s o.maxHeaderListSize = &s
} })
} }
// NewServer creates a gRPC server which has no service registered and has not // NewServer creates a gRPC server which has no service registered and has not
@ -355,15 +382,15 @@ func MaxHeaderListSize(s uint32) ServerOption {
func NewServer(opt ...ServerOption) *Server { func NewServer(opt ...ServerOption) *Server {
opts := defaultServerOptions opts := defaultServerOptions
for _, o := range opt { for _, o := range opt {
o(&opts) o.apply(&opts)
} }
s := &Server{ s := &Server{
lis: make(map[net.Listener]bool), lis: make(map[net.Listener]bool),
opts: opts, opts: opts,
conns: make(map[io.Closer]bool), conns: make(map[transport.ServerTransport]bool),
m: make(map[string]*service), m: make(map[string]*service),
quit: make(chan struct{}), quit: grpcsync.NewEvent(),
done: make(chan struct{}), done: grpcsync.NewEvent(),
czData: new(channelzData), czData: new(channelzData),
} }
s.cv = sync.NewCond(&s.mu) s.cv = sync.NewCond(&s.mu)
@ -530,11 +557,9 @@ func (s *Server) Serve(lis net.Listener) error {
s.serveWG.Add(1) s.serveWG.Add(1)
defer func() { defer func() {
s.serveWG.Done() s.serveWG.Done()
select { if s.quit.HasFired() {
// Stop or GracefulStop called; block until done and return nil. // Stop or GracefulStop called; block until done and return nil.
case <-s.quit: <-s.done.Done()
<-s.done
default:
} }
}() }()
@ -577,7 +602,7 @@ func (s *Server) Serve(lis net.Listener) error {
timer := time.NewTimer(tempDelay) timer := time.NewTimer(tempDelay)
select { select {
case <-timer.C: case <-timer.C:
case <-s.quit: case <-s.quit.Done():
timer.Stop() timer.Stop()
return nil return nil
} }
@ -587,10 +612,8 @@ func (s *Server) Serve(lis net.Listener) error {
s.printf("done serving; Accept = %v", err) s.printf("done serving; Accept = %v", err)
s.mu.Unlock() s.mu.Unlock()
select { if s.quit.HasFired() {
case <-s.quit:
return nil return nil
default:
} }
return err return err
} }
@ -611,6 +634,10 @@ func (s *Server) Serve(lis net.Listener) error {
// handleRawConn forks a goroutine to handle a just-accepted connection that // handleRawConn forks a goroutine to handle a just-accepted connection that
// has not had any I/O performed on it yet. // has not had any I/O performed on it yet.
func (s *Server) handleRawConn(rawConn net.Conn) { func (s *Server) handleRawConn(rawConn net.Conn) {
if s.quit.HasFired() {
rawConn.Close()
return
}
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
conn, authInfo, err := s.useTransportAuthenticator(rawConn) conn, authInfo, err := s.useTransportAuthenticator(rawConn)
if err != nil { if err != nil {
@ -627,14 +654,6 @@ func (s *Server) handleRawConn(rawConn net.Conn) {
return return
} }
s.mu.Lock()
if s.conns == nil {
s.mu.Unlock()
conn.Close()
return
}
s.mu.Unlock()
// Finish handshaking (HTTP2) // Finish handshaking (HTTP2)
st := s.newHTTP2Transport(conn, authInfo) st := s.newHTTP2Transport(conn, authInfo)
if st == nil { if st == nil {
@ -742,6 +761,9 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. // traceInfo returns a traceInfo and associates it with stream, if tracing is enabled.
// If tracing is not enabled, it returns nil. // If tracing is not enabled, it returns nil.
func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) {
if !EnableTracing {
return nil
}
tr, ok := trace.FromContext(stream.Context()) tr, ok := trace.FromContext(stream.Context())
if !ok { if !ok {
return nil return nil
@ -760,27 +782,27 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea
return trInfo return trInfo
} }
func (s *Server) addConn(c io.Closer) bool { func (s *Server) addConn(st transport.ServerTransport) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns == nil { if s.conns == nil {
c.Close() st.Close()
return false return false
} }
if s.drain { if s.drain {
// Transport added after we drained our existing conns: drain it // Transport added after we drained our existing conns: drain it
// immediately. // immediately.
c.(transport.ServerTransport).Drain() st.Drain()
} }
s.conns[c] = true s.conns[st] = true
return true return true
} }
func (s *Server) removeConn(c io.Closer) { func (s *Server) removeConn(st transport.ServerTransport) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.conns != nil { if s.conns != nil {
delete(s.conns, c) delete(s.conns, st)
s.cv.Broadcast() s.cv.Broadcast()
} }
} }
@ -952,10 +974,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
} }
if sh != nil { if sh != nil {
sh.HandleRPC(stream.Context(), &stats.InPayload{ sh.HandleRPC(stream.Context(), &stats.InPayload{
RecvTime: time.Now(), RecvTime: time.Now(),
Payload: v, Payload: v,
Data: d, WireLength: payInfo.wireLength,
Length: len(d), Data: d,
Length: len(d),
}) })
} }
if binlog != nil { if binlog != nil {
@ -1051,7 +1074,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// TODO: Should we be logging if writing status failed here, like above? // TODO: Should we be logging if writing status failed here, like above?
// Should the logging be in WriteStatus? Should we ignore the WriteStatus // Should the logging be in WriteStatus? Should we ignore the WriteStatus
// error or allow the stats handler to see it? // error or allow the stats handler to see it?
err = t.WriteStatus(stream, status.New(codes.OK, "")) err = t.WriteStatus(stream, statusOK)
if binlog != nil { if binlog != nil {
binlog.Log(&binarylog.ServerTrailer{ binlog.Log(&binarylog.ServerTrailer{
Trailer: stream.Trailer(), Trailer: stream.Trailer(),
@ -1209,7 +1232,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ss.trInfo.tr.LazyLog(stringer("OK"), false) ss.trInfo.tr.LazyLog(stringer("OK"), false)
ss.mu.Unlock() ss.mu.Unlock()
} }
err = t.WriteStatus(ss.s, status.New(codes.OK, "")) err = t.WriteStatus(ss.s, statusOK)
if ss.binlog != nil { if ss.binlog != nil {
ss.binlog.Log(&binarylog.ServerTrailer{ ss.binlog.Log(&binarylog.ServerTrailer{
Trailer: ss.s.Trailer(), Trailer: ss.s.Trailer(),
@ -1326,15 +1349,11 @@ func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream
// pending RPCs on the client side will get notified by connection // pending RPCs on the client side will get notified by connection
// errors. // errors.
func (s *Server) Stop() { func (s *Server) Stop() {
s.quitOnce.Do(func() { s.quit.Fire()
close(s.quit)
})
defer func() { defer func() {
s.serveWG.Wait() s.serveWG.Wait()
s.doneOnce.Do(func() { s.done.Fire()
close(s.done)
})
}() }()
s.channelzRemoveOnce.Do(func() { s.channelzRemoveOnce.Do(func() {
@ -1371,15 +1390,8 @@ func (s *Server) Stop() {
// accepting new connections and RPCs and blocks until all the pending RPCs are // accepting new connections and RPCs and blocks until all the pending RPCs are
// finished. // finished.
func (s *Server) GracefulStop() { func (s *Server) GracefulStop() {
s.quitOnce.Do(func() { s.quit.Fire()
close(s.quit) defer s.done.Fire()
})
defer func() {
s.doneOnce.Do(func() {
close(s.done)
})
}()
s.channelzRemoveOnce.Do(func() { s.channelzRemoveOnce.Do(func() {
if channelz.IsOn() { if channelz.IsOn() {
@ -1397,8 +1409,8 @@ func (s *Server) GracefulStop() {
} }
s.lis = nil s.lis = nil
if !s.drain { if !s.drain {
for c := range s.conns { for st := range s.conns {
c.(transport.ServerTransport).Drain() st.Drain()
} }
s.drain = true s.drain = true
} }

View file

@ -25,8 +25,11 @@ import (
"strings" "strings"
"time" "time"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig"
) )
const maxInt = int(^uint(0) >> 1) const maxInt = int(^uint(0) >> 1)
@ -61,6 +64,11 @@ type MethodConfig struct {
retryPolicy *retryPolicy retryPolicy *retryPolicy
} }
type lbConfig struct {
name string
cfg serviceconfig.LoadBalancingConfig
}
// ServiceConfig is provided by the service provider and contains parameters for how // ServiceConfig is provided by the service provider and contains parameters for how
// clients that connect to the service should behave. // clients that connect to the service should behave.
// //
@ -68,10 +76,18 @@ type MethodConfig struct {
// through name resolver, as specified here // through name resolver, as specified here
// https://github.com/grpc/grpc/blob/master/doc/service_config.md // https://github.com/grpc/grpc/blob/master/doc/service_config.md
type ServiceConfig struct { type ServiceConfig struct {
// LB is the load balancer the service providers recommends. The balancer specified serviceconfig.Config
// via grpc.WithBalancer will override this.
// LB is the load balancer the service providers recommends. The balancer
// specified via grpc.WithBalancer will override this. This is deprecated;
// lbConfigs is preferred. If lbConfig and LB are both present, lbConfig
// will be used.
LB *string LB *string
// lbConfig is the service config's load balancing configuration. If
// lbConfig and LB are both present, lbConfig will be used.
lbConfig *lbConfig
// Methods contains a map for the methods in this service. If there is an // Methods contains a map for the methods in this service. If there is an
// exact match for a method (i.e. /service/method) in the map, use the // exact match for a method (i.e. /service/method) in the map, use the
// corresponding MethodConfig. If there's no exact match, look for the // corresponding MethodConfig. If there's no exact match, look for the
@ -233,15 +249,27 @@ type jsonMC struct {
RetryPolicy *jsonRetryPolicy RetryPolicy *jsonRetryPolicy
} }
type loadBalancingConfig map[string]json.RawMessage
// TODO(lyuxuan): delete this struct after cleaning up old service config implementation. // TODO(lyuxuan): delete this struct after cleaning up old service config implementation.
type jsonSC struct { type jsonSC struct {
LoadBalancingPolicy *string LoadBalancingPolicy *string
LoadBalancingConfig *[]loadBalancingConfig
MethodConfig *[]jsonMC MethodConfig *[]jsonMC
RetryThrottling *retryThrottlingPolicy RetryThrottling *retryThrottlingPolicy
HealthCheckConfig *healthCheckConfig HealthCheckConfig *healthCheckConfig
} }
func init() {
internal.ParseServiceConfig = func(sc string) (interface{}, error) {
return parseServiceConfig(sc)
}
}
func parseServiceConfig(js string) (*ServiceConfig, error) { func parseServiceConfig(js string) (*ServiceConfig, error) {
if len(js) == 0 {
return nil, fmt.Errorf("no JSON service config provided")
}
var rsc jsonSC var rsc jsonSC
err := json.Unmarshal([]byte(js), &rsc) err := json.Unmarshal([]byte(js), &rsc)
if err != nil { if err != nil {
@ -255,10 +283,38 @@ func parseServiceConfig(js string) (*ServiceConfig, error) {
healthCheckConfig: rsc.HealthCheckConfig, healthCheckConfig: rsc.HealthCheckConfig,
rawJSONString: js, rawJSONString: js,
} }
if rsc.LoadBalancingConfig != nil {
for i, lbcfg := range *rsc.LoadBalancingConfig {
if len(lbcfg) != 1 {
err := fmt.Errorf("invalid loadBalancingConfig: entry %v does not contain exactly 1 policy/config pair: %q", i, lbcfg)
grpclog.Warningf(err.Error())
return nil, err
}
var name string
var jsonCfg json.RawMessage
for name, jsonCfg = range lbcfg {
}
builder := balancer.Get(name)
if builder == nil {
continue
}
sc.lbConfig = &lbConfig{name: name}
if parser, ok := builder.(balancer.ConfigParser); ok {
var err error
sc.lbConfig.cfg, err = parser.ParseConfig(jsonCfg)
if err != nil {
return nil, fmt.Errorf("error parsing loadBalancingConfig for policy %q: %v", name, err)
}
} else if string(jsonCfg) != "{}" {
grpclog.Warningf("non-empty balancer configuration %q, but balancer does not implement ParseConfig", string(jsonCfg))
}
break
}
}
if rsc.MethodConfig == nil { if rsc.MethodConfig == nil {
return &sc, nil return &sc, nil
} }
for _, m := range *rsc.MethodConfig { for _, m := range *rsc.MethodConfig {
if m.Name == nil { if m.Name == nil {
continue continue
@ -299,11 +355,11 @@ func parseServiceConfig(js string) (*ServiceConfig, error) {
} }
if sc.retryThrottling != nil { if sc.retryThrottling != nil {
if sc.retryThrottling.MaxTokens <= 0 || if mt := sc.retryThrottling.MaxTokens; mt <= 0 || mt > 1000 {
sc.retryThrottling.MaxTokens > 1000 || return nil, fmt.Errorf("invalid retry throttling config: maxTokens (%v) out of range (0, 1000]", mt)
sc.retryThrottling.TokenRatio <= 0 { }
// Illegal throttling config; disable throttling. if tr := sc.retryThrottling.TokenRatio; tr <= 0 {
sc.retryThrottling = nil return nil, fmt.Errorf("invalid retry throttling config: tokenRatio (%v) may not be negative", tr)
} }
} }
return &sc, nil return &sc, nil

View file

@ -0,0 +1,48 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package serviceconfig defines types and methods for operating on gRPC
// service configs.
//
// This package is EXPERIMENTAL.
package serviceconfig
import (
"google.golang.org/grpc/internal"
)
// Config represents an opaque data structure holding a service config.
type Config interface {
isConfig()
}
// LoadBalancingConfig represents an opaque data structure holding a load
// balancer config.
type LoadBalancingConfig interface {
isLoadBalancingConfig()
}
// Parse parses the JSON service config provided into an internal form or
// returns an error if the config is invalid.
func Parse(ServiceConfigJSON string) (Config, error) {
c, err := internal.ParseServiceConfig(ServiceConfigJSON)
if err != nil {
return nil, err
}
return c.(Config), err
}

View file

@ -36,8 +36,15 @@ import (
"github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes"
spb "google.golang.org/genproto/googleapis/rpc/status" spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/internal"
) )
func init() {
internal.StatusRawProto = statusRawProto
}
func statusRawProto(s *Status) *spb.Status { return s.s }
// statusError is an alias of a status proto. It implements error and Status, // statusError is an alias of a status proto. It implements error and Status,
// and a nil statusError should never be returned by this package. // and a nil statusError should never be returned by this package.
type statusError spb.Status type statusError spb.Status
@ -51,6 +58,17 @@ func (se *statusError) GRPCStatus() *Status {
return &Status{s: (*spb.Status)(se)} return &Status{s: (*spb.Status)(se)}
} }
// Is implements future error.Is functionality.
// A statusError is equivalent if the code and message are identical.
func (se *statusError) Is(target error) bool {
tse, ok := target.(*statusError)
if !ok {
return false
}
return proto.Equal((*spb.Status)(se), (*spb.Status)(tse))
}
// Status represents an RPC status code, message, and details. It is immutable // Status represents an RPC status code, message, and details. It is immutable
// and should be created with New, Newf, or FromProto. // and should be created with New, Newf, or FromProto.
type Status struct { type Status struct {
@ -125,7 +143,7 @@ func FromProto(s *spb.Status) *Status {
// Status is returned with codes.Unknown and the original error message. // Status is returned with codes.Unknown and the original error message.
func FromError(err error) (s *Status, ok bool) { func FromError(err error) (s *Status, ok bool) {
if err == nil { if err == nil {
return &Status{s: &spb.Status{Code: int32(codes.OK)}}, true return nil, true
} }
if se, ok := err.(interface { if se, ok := err.(interface {
GRPCStatus() *Status GRPCStatus() *Status
@ -199,7 +217,7 @@ func Code(err error) codes.Code {
func FromContextError(err error) *Status { func FromContextError(err error) *Status {
switch err { switch err {
case nil: case nil:
return New(codes.OK, "") return nil
case context.DeadlineExceeded: case context.DeadlineExceeded:
return New(codes.DeadlineExceeded, err.Error()) return New(codes.DeadlineExceeded, err.Error())
case context.Canceled: case context.Canceled:

View file

@ -30,7 +30,6 @@ import (
"golang.org/x/net/trace" "golang.org/x/net/trace"
"google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding"
"google.golang.org/grpc/grpclog" "google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/balancerload" "google.golang.org/grpc/internal/balancerload"
@ -245,7 +244,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
trInfo.tr.LazyLog(&trInfo.firstLine, false) trInfo.tr.LazyLog(&trInfo.firstLine, false)
ctx = trace.NewContext(ctx, trInfo.tr) ctx = trace.NewContext(ctx, trInfo.tr)
} }
ctx = newContextWithRPCInfo(ctx, c.failFast) ctx = newContextWithRPCInfo(ctx, c.failFast, c.codec, cp, comp)
sh := cc.dopts.copts.StatsHandler sh := cc.dopts.copts.StatsHandler
var beginTime time.Time var beginTime time.Time
if sh != nil { if sh != nil {
@ -328,13 +327,23 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
return cs, nil return cs, nil
} }
func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) error { // newAttemptLocked creates a new attempt with a transport.
cs.attempt = &csAttempt{ // If it succeeds, then it replaces clientStream's attempt with this new attempt.
func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) (retErr error) {
newAttempt := &csAttempt{
cs: cs, cs: cs,
dc: cs.cc.dopts.dc, dc: cs.cc.dopts.dc,
statsHandler: sh, statsHandler: sh,
trInfo: trInfo, trInfo: trInfo,
} }
defer func() {
if retErr != nil {
// This attempt is not set in the clientStream, so it's finish won't
// be called. Call it here for stats and trace in case they are not
// nil.
newAttempt.finish(retErr)
}
}()
if err := cs.ctx.Err(); err != nil { if err := cs.ctx.Err(); err != nil {
return toRPCErr(err) return toRPCErr(err)
@ -346,8 +355,9 @@ func (cs *clientStream) newAttemptLocked(sh stats.Handler, trInfo *traceInfo) er
if trInfo != nil { if trInfo != nil {
trInfo.firstLine.SetRemoteAddr(t.RemoteAddr()) trInfo.firstLine.SetRemoteAddr(t.RemoteAddr())
} }
cs.attempt.t = t newAttempt.t = t
cs.attempt.done = done newAttempt.done = done
cs.attempt = newAttempt
return nil return nil
} }
@ -396,11 +406,18 @@ type clientStream struct {
serverHeaderBinlogged bool serverHeaderBinlogged bool
mu sync.Mutex mu sync.Mutex
firstAttempt bool // if true, transparent retry is valid firstAttempt bool // if true, transparent retry is valid
numRetries int // exclusive of transparent retry attempt(s) numRetries int // exclusive of transparent retry attempt(s)
numRetriesSincePushback int // retries since pushback; to reset backoff numRetriesSincePushback int // retries since pushback; to reset backoff
finished bool // TODO: replace with atomic cmpxchg or sync.Once? finished bool // TODO: replace with atomic cmpxchg or sync.Once?
attempt *csAttempt // the active client stream attempt // attempt is the active client stream attempt.
// The only place where it is written is the newAttemptLocked method and this method never writes nil.
// So, attempt can be nil only inside newClientStream function when clientStream is first created.
// One of the first things done after clientStream's creation, is to call newAttemptLocked which either
// assigns a non nil value to the attempt or returns an error. If an error is returned from newAttemptLocked,
// then newClientStream calls finish on the clientStream and returns. So, finish method is the only
// place where we need to check if the attempt is nil.
attempt *csAttempt
// TODO(hedging): hedging will have multiple attempts simultaneously. // TODO(hedging): hedging will have multiple attempts simultaneously.
committed bool // active attempt committed for retry? committed bool // active attempt committed for retry?
buffer []func(a *csAttempt) error // operations to replay on retry buffer []func(a *csAttempt) error // operations to replay on retry
@ -458,8 +475,8 @@ func (cs *clientStream) shouldRetry(err error) error {
if cs.attempt.s != nil { if cs.attempt.s != nil {
<-cs.attempt.s.Done() <-cs.attempt.s.Done()
} }
if cs.firstAttempt && !cs.callInfo.failFast && (cs.attempt.s == nil || cs.attempt.s.Unprocessed()) { if cs.firstAttempt && (cs.attempt.s == nil || cs.attempt.s.Unprocessed()) {
// First attempt, wait-for-ready, stream unprocessed: transparently retry. // First attempt, stream unprocessed: transparently retry.
cs.firstAttempt = false cs.firstAttempt = false
return nil return nil
} }
@ -677,15 +694,13 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if !cs.desc.ClientStreams { if !cs.desc.ClientStreams {
cs.sentLast = true cs.sentLast = true
} }
data, err := encode(cs.codec, m)
// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
if err != nil { if err != nil {
return err return err
} }
compData, err := compress(data, cs.cp, cs.comp)
if err != nil {
return err
}
hdr, payload := msgHeader(data, compData)
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > *cs.callInfo.maxSendMessageSize { if len(payload) > *cs.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
@ -808,11 +823,11 @@ func (cs *clientStream) finish(err error) {
} }
if cs.attempt != nil { if cs.attempt != nil {
cs.attempt.finish(err) cs.attempt.finish(err)
} // after functions all rely upon having a stream.
// after functions all rely upon having a stream. if cs.attempt.s != nil {
if cs.attempt.s != nil { for _, o := range cs.opts {
for _, o := range cs.opts { o.after(cs.callInfo)
o.after(cs.callInfo) }
} }
} }
cs.cancel() cs.cancel()
@ -967,19 +982,18 @@ func (a *csAttempt) finish(err error) {
a.mu.Unlock() a.mu.Unlock()
} }
func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, opts ...CallOption) (_ ClientStream, err error) { // newClientStream creates a ClientStream with the specified transport, on the
ac.mu.Lock() // given addrConn.
if ac.transport != t { //
ac.mu.Unlock() // It's expected that the given transport is either the same one in addrConn, or
return nil, status.Error(codes.Canceled, "the provided transport is no longer valid to use") // is already closed. To avoid race, transport is specified separately, instead
} // of using ac.transpot.
// transition to CONNECTING state when an attempt starts //
if ac.state != connectivity.Connecting { // Main difference between this and ClientConn.NewStream:
ac.updateConnectivityState(connectivity.Connecting) // - no retry
ac.cc.handleSubConnStateChange(ac.acbw, ac.state) // - no service config (or wait for service config)
} // - no tracing or stats
ac.mu.Unlock() func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method string, t transport.ClientTransport, ac *addrConn, opts ...CallOption) (_ ClientStream, err error) {
if t == nil { if t == nil {
// TODO: return RPC error here? // TODO: return RPC error here?
return nil, errors.New("transport provided is nil") return nil, errors.New("transport provided is nil")
@ -987,14 +1001,6 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
// defaultCallInfo contains unnecessary info(i.e. failfast, maxRetryRPCBufferSize), so we just initialize an empty struct. // defaultCallInfo contains unnecessary info(i.e. failfast, maxRetryRPCBufferSize), so we just initialize an empty struct.
c := &callInfo{} c := &callInfo{}
for _, o := range opts {
if err := o.before(c); err != nil {
return nil, toRPCErr(err)
}
}
c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize)
// Possible context leak: // Possible context leak:
// The cancel function for the child context we create will only be called // The cancel function for the child context we create will only be called
// when RecvMsg returns a non-nil error, if the ClientConn is closed, or if // when RecvMsg returns a non-nil error, if the ClientConn is closed, or if
@ -1007,6 +1013,13 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
} }
}() }()
for _, o := range opts {
if err := o.before(c); err != nil {
return nil, toRPCErr(err)
}
}
c.maxReceiveMessageSize = getMaxSize(nil, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize)
c.maxSendMessageSize = getMaxSize(nil, c.maxSendMessageSize, defaultServerMaxSendMessageSize)
if err := setCallInfoCodec(c); err != nil { if err := setCallInfoCodec(c); err != nil {
return nil, err return nil, err
} }
@ -1039,6 +1052,7 @@ func (ac *addrConn) newClientStream(ctx context.Context, desc *StreamDesc, metho
callHdr.Creds = c.creds callHdr.Creds = c.creds
} }
// Use a special addrConnStream to avoid retry.
as := &addrConnStream{ as := &addrConnStream{
callHdr: callHdr, callHdr: callHdr,
ac: ac, ac: ac,
@ -1150,15 +1164,13 @@ func (as *addrConnStream) SendMsg(m interface{}) (err error) {
if !as.desc.ClientStreams { if !as.desc.ClientStreams {
as.sentLast = true as.sentLast = true
} }
data, err := encode(as.codec, m)
// load hdr, payload, data
hdr, payld, _, err := prepareMsg(m, as.codec, as.cp, as.comp)
if err != nil { if err != nil {
return err return err
} }
compData, err := compress(data, as.cp, as.comp)
if err != nil {
return err
}
hdr, payld := msgHeader(data, compData)
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payld) > *as.callInfo.maxSendMessageSize { if len(payld) > *as.callInfo.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payld), *as.callInfo.maxSendMessageSize)
@ -1395,15 +1407,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
ss.t.IncrMsgSent() ss.t.IncrMsgSent()
} }
}() }()
data, err := encode(ss.codec, m)
// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil { if err != nil {
return err return err
} }
compData, err := compress(data, ss.cp, ss.comp)
if err != nil {
return err
}
hdr, payload := msgHeader(data, compData)
// TODO(dfawley): should we be checking len(data) instead? // TODO(dfawley): should we be checking len(data) instead?
if len(payload) > ss.maxSendMessageSize { if len(payload) > ss.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize) return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), ss.maxSendMessageSize)
@ -1496,3 +1506,24 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
func MethodFromServerStream(stream ServerStream) (string, bool) { func MethodFromServerStream(stream ServerStream) (string, bool) {
return Method(stream.Context()) return Method(stream.Context())
} }
// prepareMsg returns the hdr, payload and data
// using the compressors passed or using the
// passed preparedmsg
func prepareMsg(m interface{}, codec baseCodec, cp Compressor, comp encoding.Compressor) (hdr, payload, data []byte, err error) {
if preparedMsg, ok := m.(*PreparedMsg); ok {
return preparedMsg.hdr, preparedMsg.payload, preparedMsg.encodedData, nil
}
// The input interface is not a prepared msg.
// Marshal and Compress the data at this point
data, err = encode(codec, m)
if err != nil {
return nil, nil, nil, err
}
compData, err := compress(data, cp, comp)
if err != nil {
return nil, nil, nil, err
}
hdr, payload = msgHeader(data, compData)
return hdr, payload, data, nil
}

View file

@ -19,4 +19,4 @@
package grpc package grpc
// Version is the current grpc version. // Version is the current grpc version.
const Version = "1.20.1" const Version = "1.23.0"