mirror of https://github.com/grpc/grpc-go.git
credentials: plumb cancellation into ClientHandshake
This is a minor breaking change to `TransportCredentials`, however it should not be a problem in practice as not many users are using custom implementations. In particular, users of `NewTLS` will not be affected. This change also replaces the earlier `Timeout` and `Cancel` fields with a `context.Context`, which is plumbed all the way down from `grpc.Dial`, laying the ground work for a user-provided context. Also, support for Go 1.7 is added.
This commit is contained in:
parent
5a423e610f
commit
5c7ed938f9
|
@ -198,8 +198,11 @@ func WithTimeout(d time.Duration) DialOption {
|
|||
// WithDialer returns a DialOption that specifies a function to use for dialing network addresses.
|
||||
func WithDialer(f func(string, time.Duration) (net.Conn, error)) DialOption {
|
||||
return func(o *dialOptions) {
|
||||
o.copts.Dialer = func(addr string, timeout time.Duration, _ <-chan struct{}) (net.Conn, error) {
|
||||
return f(addr, timeout)
|
||||
o.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
return f(addr, deadline.Sub(time.Now()))
|
||||
}
|
||||
return f(addr, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -213,10 +216,12 @@ func WithUserAgent(s string) DialOption {
|
|||
|
||||
// Dial creates a client connection the given target.
|
||||
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
||||
ctx := context.Background()
|
||||
cc := &ClientConn{
|
||||
target: target,
|
||||
conns: make(map[Address]*addrConn),
|
||||
}
|
||||
cc.ctx, cc.cancel = context.WithCancel(ctx)
|
||||
for _, opt := range opts {
|
||||
opt(&cc.dopts)
|
||||
}
|
||||
|
@ -269,6 +274,9 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
|
|||
cc.Close()
|
||||
return nil, err
|
||||
}
|
||||
case <-cc.ctx.Done():
|
||||
cc.Close()
|
||||
return nil, cc.ctx.Err()
|
||||
case <-timeoutCh:
|
||||
cc.Close()
|
||||
return nil, ErrClientConnTimeout
|
||||
|
@ -319,6 +327,9 @@ func (s ConnectivityState) String() string {
|
|||
|
||||
// ClientConn represents a client connection to an RPC server.
|
||||
type ClientConn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
target string
|
||||
authority string
|
||||
dopts dialOptions
|
||||
|
@ -371,8 +382,8 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
|
|||
addr: addr,
|
||||
dopts: cc.dopts,
|
||||
}
|
||||
ac.ctx, ac.cancel = context.WithCancel(cc.ctx)
|
||||
ac.stateCV = sync.NewCond(&ac.mu)
|
||||
ac.dopts.copts.Cancel = make(chan struct{})
|
||||
if EnableTracing {
|
||||
ac.events = trace.NewEventLog("grpc.ClientConn", ac.addr.Addr)
|
||||
}
|
||||
|
@ -390,15 +401,15 @@ func (cc *ClientConn) resetAddrConn(addr Address, skipWait bool, tearDownErr err
|
|||
}
|
||||
}
|
||||
}
|
||||
// Insert ac into ac.cc.conns. This needs to be done before any getTransport(...) is called.
|
||||
ac.cc.mu.Lock()
|
||||
if ac.cc.conns == nil {
|
||||
ac.cc.mu.Unlock()
|
||||
// Track ac in cc. This needs to be done before any getTransport(...) is called.
|
||||
cc.mu.Lock()
|
||||
if cc.conns == nil {
|
||||
cc.mu.Unlock()
|
||||
return ErrClientConnClosing
|
||||
}
|
||||
stale := ac.cc.conns[ac.addr]
|
||||
ac.cc.conns[ac.addr] = ac
|
||||
ac.cc.mu.Unlock()
|
||||
stale := cc.conns[ac.addr]
|
||||
cc.conns[ac.addr] = ac
|
||||
cc.mu.Unlock()
|
||||
if stale != nil {
|
||||
// There is an addrConn alive on ac.addr already. This could be due to
|
||||
// 1) a buggy Balancer notifies duplicated Addresses;
|
||||
|
@ -473,6 +484,8 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
|
|||
|
||||
// Close tears down the ClientConn and all underlying connections.
|
||||
func (cc *ClientConn) Close() error {
|
||||
cc.cancel()
|
||||
|
||||
cc.mu.Lock()
|
||||
if cc.conns == nil {
|
||||
cc.mu.Unlock()
|
||||
|
@ -490,6 +503,9 @@ func (cc *ClientConn) Close() error {
|
|||
|
||||
// addrConn is a network connection to a given address.
|
||||
type addrConn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
cc *ClientConn
|
||||
addr Address
|
||||
dopts dialOptions
|
||||
|
@ -579,14 +595,16 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||
t.Close()
|
||||
}
|
||||
sleepTime := ac.dopts.bs.backoff(retries)
|
||||
copts := ac.dopts.copts
|
||||
copts.Timeout = sleepTime
|
||||
if sleepTime < minConnectTimeout {
|
||||
copts.Timeout = minConnectTimeout
|
||||
timeout := minConnectTimeout
|
||||
if timeout < sleepTime {
|
||||
timeout = sleepTime
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
|
||||
connectTime := time.Now()
|
||||
newTransport, err := transport.NewClientTransport(ac.addr.Addr, copts)
|
||||
newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
|
||||
if err != nil {
|
||||
cancel()
|
||||
|
||||
ac.mu.Lock()
|
||||
if ac.state == Shutdown {
|
||||
// ac.tearDown(...) has been invoked.
|
||||
|
@ -601,14 +619,11 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
|
|||
ac.ready = nil
|
||||
}
|
||||
ac.mu.Unlock()
|
||||
sleepTime -= time.Since(connectTime)
|
||||
if sleepTime < 0 {
|
||||
sleepTime = 0
|
||||
}
|
||||
closeTransport = false
|
||||
select {
|
||||
case <-time.After(sleepTime):
|
||||
case <-ac.dopts.copts.Cancel:
|
||||
case <-time.After(sleepTime - time.Since(connectTime)):
|
||||
case <-ac.ctx.Done():
|
||||
return ac.ctx.Err()
|
||||
}
|
||||
retries++
|
||||
grpclog.Printf("grpc: addrConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, ac.addr)
|
||||
|
@ -643,9 +658,9 @@ func (ac *addrConn) transportMonitor() {
|
|||
t := ac.transport
|
||||
ac.mu.Unlock()
|
||||
select {
|
||||
// Cancel is needed to detect the teardown when
|
||||
// This is needed to detect the teardown when
|
||||
// the addrConn is idle (i.e., no RPC in flight).
|
||||
case <-ac.dopts.copts.Cancel:
|
||||
case <-ac.ctx.Done():
|
||||
select {
|
||||
case <-t.Error():
|
||||
t.Close()
|
||||
|
@ -668,7 +683,7 @@ func (ac *addrConn) transportMonitor() {
|
|||
return
|
||||
case <-t.Error():
|
||||
select {
|
||||
case <-ac.dopts.copts.Cancel:
|
||||
case <-ac.ctx.Done():
|
||||
t.Close()
|
||||
return
|
||||
case <-t.GoAway():
|
||||
|
@ -735,6 +750,8 @@ func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTr
|
|||
// tight loop.
|
||||
// tearDown doesn't remove ac from ac.cc.conns.
|
||||
func (ac *addrConn) tearDown(err error) {
|
||||
ac.cancel()
|
||||
|
||||
ac.mu.Lock()
|
||||
defer ac.mu.Unlock()
|
||||
if ac.down != nil {
|
||||
|
@ -764,8 +781,5 @@ func (ac *addrConn) tearDown(err error) {
|
|||
if ac.transport != nil && err != errConnDrain {
|
||||
ac.transport.Close()
|
||||
}
|
||||
if ac.dopts.copts.Cancel != nil {
|
||||
close(ac.dopts.copts.Cancel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
@ -93,11 +92,12 @@ type TransportCredentials interface {
|
|||
// ClientHandshake does the authentication handshake specified by the corresponding
|
||||
// authentication protocol on rawConn for clients. It returns the authenticated
|
||||
// connection and the corresponding auth information about the connection.
|
||||
ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error)
|
||||
// Implementations must use the provided context to implement timely cancellation.
|
||||
ClientHandshake(context.Context, string, net.Conn) (net.Conn, AuthInfo, error)
|
||||
// ServerHandshake does the authentication handshake for servers. It returns
|
||||
// the authenticated connection and the corresponding auth information about
|
||||
// the connection.
|
||||
ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||
ServerHandshake(net.Conn) (net.Conn, AuthInfo, error)
|
||||
// Info provides the ProtocolInfo of this TransportCredentials.
|
||||
Info() ProtocolInfo
|
||||
}
|
||||
|
@ -136,21 +136,7 @@ func (c *tlsCreds) RequireTransportSecurity() bool {
|
|||
return true
|
||||
}
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (timeoutError) Error() string { return "credentials: Dial timed out" }
|
||||
func (timeoutError) Timeout() bool { return true }
|
||||
func (timeoutError) Temporary() bool { return true }
|
||||
|
||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) {
|
||||
// borrow some code from tls.DialWithDialer
|
||||
var errChannel chan error
|
||||
if timeout != 0 {
|
||||
errChannel = make(chan error, 2)
|
||||
time.AfterFunc(timeout, func() {
|
||||
errChannel <- timeoutError{}
|
||||
})
|
||||
}
|
||||
func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
|
||||
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||
cfg := cloneTLSConfig(c.config)
|
||||
if c.config.ServerName == "" {
|
||||
|
@ -161,17 +147,18 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
|
|||
cfg.ServerName = addr[:colonPos]
|
||||
}
|
||||
conn := tls.Client(rawConn, cfg)
|
||||
if timeout == 0 {
|
||||
err = conn.Handshake()
|
||||
} else {
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
err = <-errChannel
|
||||
}
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, nil, err
|
||||
errChannel := make(chan error, 1)
|
||||
go func() {
|
||||
errChannel <- conn.Handshake()
|
||||
}()
|
||||
select {
|
||||
case err := <-errChannel:
|
||||
if err != nil {
|
||||
rawConn.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil, nil, ctx.Err()
|
||||
}
|
||||
// TODO(zhaoq): Omit the auth info for client now. It is more for
|
||||
// information than anything else.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// +build go1.6
|
||||
// +build go1.6,!go1.7
|
||||
|
||||
/*
|
||||
* Copyright 2014, Google Inc.
|
||||
* Copyright 2016, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
|
@ -36,10 +36,11 @@ package transport
|
|||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// newDialer constructs a net.Dialer.
|
||||
func newDialer(timeout time.Duration, cancel <-chan struct{}) *net.Dialer {
|
||||
return &net.Dialer{Cancel: cancel, Timeout: timeout}
|
||||
// dialContext connects to the address on the named network.
|
||||
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
// +build go1.7
|
||||
|
||||
/*
|
||||
* Copyright 2016, Google Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
* * Neither the name of Google Inc. nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
*/
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// dialContext connects to the address on the named network.
|
||||
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return (&net.Dialer{}).DialContext(ctx, network, address)
|
||||
}
|
|
@ -107,31 +107,26 @@ type http2Client struct {
|
|||
prevGoAwayID uint32
|
||||
}
|
||||
|
||||
func dial(fn func(string, time.Duration, <-chan struct{}) (net.Conn, error), addr string, timeout time.Duration, cancel <-chan struct{}) (net.Conn, error) {
|
||||
func dial(fn func(context.Context, string) (net.Conn, error), ctx context.Context, addr string) (net.Conn, error) {
|
||||
if fn != nil {
|
||||
return fn(addr, timeout, cancel)
|
||||
return fn(ctx, addr)
|
||||
}
|
||||
return newDialer(timeout, cancel).Dial("tcp", addr)
|
||||
return dialContext(ctx, "tcp", addr)
|
||||
}
|
||||
|
||||
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
|
||||
// and starts to receive messages on it. Non-nil error returns if construction
|
||||
// fails.
|
||||
func newHTTP2Client(addr string, opts ConnectOptions) (_ ClientTransport, err error) {
|
||||
func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
|
||||
scheme := "http"
|
||||
startT := time.Now()
|
||||
timeout := opts.Timeout
|
||||
conn, connErr := dial(opts.Dialer, addr, timeout, opts.Cancel)
|
||||
conn, connErr := dial(opts.Dialer, ctx, addr)
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
}
|
||||
var authInfo credentials.AuthInfo
|
||||
if opts.TransportCredentials != nil {
|
||||
if creds := opts.TransportCredentials; creds != nil {
|
||||
scheme = "https"
|
||||
if timeout > 0 {
|
||||
timeout -= time.Since(startT)
|
||||
}
|
||||
conn, authInfo, connErr = opts.TransportCredentials.ClientHandshake(addr, conn, timeout)
|
||||
conn, authInfo, connErr = creds.ClientHandshake(ctx, addr, conn)
|
||||
}
|
||||
if connErr != nil {
|
||||
return nil, ConnectionErrorf("transport: %v", connErr)
|
||||
|
|
|
@ -37,9 +37,15 @@ package transport
|
|||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// newDialer constructs a net.Dialer.
|
||||
func newDialer(timeout time.Duration, _ <-chan struct{}) *net.Dialer {
|
||||
return &net.Dialer{Timeout: timeout}
|
||||
// dialContext connects to the address on the named network.
|
||||
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
dialer.Timeout = deadline.Sub(time.Now())
|
||||
}
|
||||
return dialer.Dial(network, address)
|
||||
}
|
||||
|
|
|
@ -44,7 +44,6 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/trace"
|
||||
|
@ -355,22 +354,18 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authI
|
|||
type ConnectOptions struct {
|
||||
// UserAgent is the application user agent.
|
||||
UserAgent string
|
||||
// Cancel is closed to indicate that dialing should be cancelled.
|
||||
Cancel chan struct{}
|
||||
// Dialer specifies how to dial a network address.
|
||||
Dialer func(string, time.Duration, <-chan struct{}) (net.Conn, error)
|
||||
Dialer func(context.Context, string) (net.Conn, error)
|
||||
// PerRPCCredentials stores the PerRPCCredentials required to issue RPCs.
|
||||
PerRPCCredentials []credentials.PerRPCCredentials
|
||||
// TransportCredentials stores the Authenticator required to setup a client connection.
|
||||
TransportCredentials credentials.TransportCredentials
|
||||
// Timeout specifies the timeout for dialing a ClientTransport.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// NewClientTransport establishes the transport with the required ConnectOptions
|
||||
// and returns it to the caller.
|
||||
func NewClientTransport(target string, opts ConnectOptions) (ClientTransport, error) {
|
||||
return newHTTP2Client(target, opts)
|
||||
func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
|
||||
return newHTTP2Client(ctx, target, opts)
|
||||
}
|
||||
|
||||
// Options provides additional hints and information for message
|
||||
|
|
|
@ -228,7 +228,7 @@ func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, Client
|
|||
ct ClientTransport
|
||||
connErr error
|
||||
)
|
||||
ct, connErr = NewClientTransport(addr, ConnectOptions{})
|
||||
ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
|
||||
if connErr != nil {
|
||||
t.Fatalf("failed to create transport: %v", connErr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue