make downErr for Balancer down closure

This commit is contained in:
iamqizhao 2016-05-25 11:28:45 -07:00
parent 8eab9cb6bf
commit 9dc3da0633
4 changed files with 65 additions and 42 deletions

View File

@ -34,6 +34,7 @@
package grpc package grpc
import ( import (
"fmt"
"sync" "sync"
"golang.org/x/net/context" "golang.org/x/net/context"
@ -103,6 +104,26 @@ type Balancer interface {
Close() error Close() error
} }
// downErr implements net.Error. It is contructed by gRPC internals and passed to the down
// call of Balancer.
type downErr struct {
timeout bool
temporary bool
desc string
}
func (e downErr) Error() string { return e.desc }
func (e downErr) Timeout() bool { return e.timeout }
func (e downErr) Temporary() bool { return e.temporary }
func downErrorf(timeout, temporary bool, format string, a ...interface{}) downErr {
return downErr{
timeout: timeout,
temporary: temporary,
desc: fmt.Sprintf(format, a...),
}
}
// RoundRobin returns a Balancer that selects addresses round-robin. It starts to watch // RoundRobin returns a Balancer that selects addresses round-robin. It starts to watch
// the name resolution updates. // the name resolution updates.
func RoundRobin(r naming.Resolver) Balancer { func RoundRobin(r naming.Resolver) Balancer {

View File

@ -154,7 +154,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if err != nil { if err != nil {
// TODO(zhaoq): Probably revisit the error handling. // TODO(zhaoq): Probably revisit the error handling.
if err == ErrClientConnClosing { if err == ErrClientConnClosing {
return toRPCErr(err) return Errorf(codes.FailedPrecondition, "%v", err)
} }
if _, ok := err.(transport.StreamError); ok { if _, ok := err.(transport.StreamError); ok {
return toRPCErr(err) return toRPCErr(err)
@ -189,6 +189,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
put = nil
} }
if _, ok := err.(transport.ConnectionError); ok { if _, ok := err.(transport.ConnectionError); ok {
if c.failFast { if c.failFast {

View File

@ -51,26 +51,27 @@ import (
) )
var ( var (
// ErrNoTransportSecurity indicates that there is no transport security
// being set for ClientConn. Users should either set one or explicitly
// call WithInsecure DialOption to disable security.
ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
// ErrCredentialsMisuse indicates that users want to transmit security information
// (e.g., oauth2 token) which requires secure connection on an insecure
// connection.
ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
// ErrClientConnClosing indicates that the operation is illegal because // ErrClientConnClosing indicates that the operation is illegal because
// the ClientConn is closing. // the ClientConn is closing.
ErrClientConnClosing = Errorf(codes.FailedPrecondition, "grpc: the client connection is closing") ErrClientConnClosing = errors.New("grpc: the client connection is closing")
// ErrClientConnTimeout indicates that the connection could not be
// errNoTransportSecurity indicates that there is no transport security
// being set for ClientConn. Users should either set one or explicitly
// call WithInsecure DialOption to disable security.
errNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)")
// errCredentialsMisuse indicates that users want to transmit security information
// (e.g., oauth2 token) which requires secure connection on an insecure
// connection.
errCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)")
// errClientConnTimeout indicates that the connection could not be
// established or re-established within the specified timeout. // established or re-established within the specified timeout.
ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") errClientConnTimeout = errors.New("grpc: timed out trying to connect")
// ErrNetworkIP indicates that the connection is down due to some network I/O error. // errNetworkIP indicates that the connection is down due to some network I/O error.
ErrNetworkIO = errors.New("grpc: failed with network I/O error") errNetworkIO = errors.New("grpc: failed with network I/O error")
// ErrConnDrain indicates that the connection starts to be drained and does not accept any new RPCs. // errConnDrain indicates that the connection starts to be drained and does not accept any new RPCs.
ErrConnDrain = errors.New("grpc: the connection is drained") errConnDrain = errors.New("grpc: the connection is drained")
// ErrConnClosing // errConnClosing indicates that the connection is closing.
ErrConnClosing = errors.New("grpc: the addrConn is closing") errConnClosing = errors.New("grpc: the connection is closing")
// minimum time to give a connection to complete // minimum time to give a connection to complete
minConnectTimeout = 20 * time.Second minConnectTimeout = 20 * time.Second
) )
@ -337,7 +338,7 @@ func (cc *ClientConn) controller() {
} }
} }
for _, c := range del { for _, c := range del {
c.tearDown(ErrConnDrain) c.tearDown(errConnDrain)
} }
} }
} }
@ -360,12 +361,12 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
} }
if !ok { if !ok {
return ErrNoTransportSecurity return errNoTransportSecurity
} }
} else { } else {
for _, cd := range ac.dopts.copts.AuthOptions { for _, cd := range ac.dopts.copts.AuthOptions {
if cd.RequireTransportSecurity() { if cd.RequireTransportSecurity() {
return ErrCredentialsMisuse return errCredentialsMisuse
} }
} }
} }
@ -529,10 +530,10 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
return ErrConnClosing return errConnClosing
} }
if ac.down != nil { if ac.down != nil {
ac.down(ErrNetworkIO) ac.down(downErrorf(false, true, "%v", errNetworkIO))
ac.down = nil ac.down = nil
} }
ac.state = Connecting ac.state = Connecting
@ -545,14 +546,14 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
// Adjust timeout for the current try. // Adjust timeout for the current try.
copts := ac.dopts.copts copts := ac.dopts.copts
if copts.Timeout < 0 { if copts.Timeout < 0 {
ac.tearDown(ErrClientConnTimeout) ac.tearDown(errClientConnTimeout)
return ErrClientConnTimeout return errClientConnTimeout
} }
if copts.Timeout > 0 { if copts.Timeout > 0 {
copts.Timeout -= time.Since(start) copts.Timeout -= time.Since(start)
if copts.Timeout <= 0 { if copts.Timeout <= 0 {
ac.tearDown(ErrClientConnTimeout) ac.tearDown(errClientConnTimeout)
return ErrClientConnTimeout return errClientConnTimeout
} }
} }
sleepTime := ac.dopts.bs.backoff(retries) sleepTime := ac.dopts.bs.backoff(retries)
@ -570,7 +571,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
if ac.state == Shutdown { if ac.state == Shutdown {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
return ErrConnClosing return errConnClosing
} }
ac.errorf("transient failure: %v", err) ac.errorf("transient failure: %v", err)
ac.state = TransientFailure ac.state = TransientFailure
@ -589,8 +590,8 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
ac.mu.Lock() ac.mu.Lock()
ac.errorf("connection timeout") ac.errorf("connection timeout")
ac.mu.Unlock() ac.mu.Unlock()
ac.tearDown(ErrClientConnTimeout) ac.tearDown(errClientConnTimeout)
return ErrClientConnTimeout return errClientConnTimeout
} }
closeTransport = false closeTransport = false
select { select {
@ -607,7 +608,7 @@ func (ac *addrConn) resetTransport(closeTransport bool) error {
// ac.tearDown(...) has been invoked. // ac.tearDown(...) has been invoked.
ac.mu.Unlock() ac.mu.Unlock()
newTransport.Close() newTransport.Close()
return ErrConnClosing return errConnClosing
} }
ac.state = Ready ac.state = Ready
ac.stateCV.Broadcast() ac.stateCV.Broadcast()
@ -662,7 +663,7 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
switch { switch {
case ac.state == Shutdown: case ac.state == Shutdown:
ac.mu.Unlock() ac.mu.Unlock()
return nil, ErrConnClosing return nil, errConnClosing
case ac.state == Ready: case ac.state == Ready:
ct := ac.transport ct := ac.transport
ac.mu.Unlock() ac.mu.Unlock()
@ -699,7 +700,7 @@ func (ac *addrConn) tearDown(err error) {
ac.cc.mu.Unlock() ac.cc.mu.Unlock()
}() }()
if ac.down != nil { if ac.down != nil {
ac.down(err) ac.down(downErrorf(false, false, "%v", err))
ac.down = nil ac.down = nil
} }
if ac.state == Shutdown { if ac.state == Shutdown {
@ -716,7 +717,7 @@ func (ac *addrConn) tearDown(err error) {
ac.ready = nil ac.ready = nil
} }
if ac.transport != nil { if ac.transport != nil {
if err == ErrConnDrain { if err == errConnDrain {
ac.transport.GracefulClose() ac.transport.GracefulClose()
} else { } else {
ac.transport.Close() ac.transport.Close()

View File

@ -47,8 +47,8 @@ func TestDialTimeout(t *testing.T) {
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
if err != ErrClientConnTimeout { if err != errClientConnTimeout {
t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) t.Fatalf("Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout)
} }
} }
@ -61,8 +61,8 @@ func TestTLSDialTimeout(t *testing.T) {
if err == nil { if err == nil {
conn.Close() conn.Close()
} }
if err != ErrClientConnTimeout { if err != errClientConnTimeout {
t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, ErrClientConnTimeout) t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, errClientConnTimeout)
} }
} }
@ -72,12 +72,12 @@ func TestCredentialsMisuse(t *testing.T) {
t.Fatalf("Failed to create credentials %v", err) t.Fatalf("Failed to create credentials %v", err)
} }
// Two conflicting credential configurations // Two conflicting credential configurations
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse { if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
} }
// security info on insecure connection // security info on insecure connection
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != ErrCredentialsMisuse { if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(creds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsMisuse {
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, ErrCredentialsMisuse) t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsMisuse)
} }
} }