Support fail-fast mode and make it the default

This commit is contained in:
iamqizhao 2016-06-27 14:36:59 -07:00
parent 0eb7c5dcd0
commit 3e71fb360d
7 changed files with 121 additions and 26 deletions

View File

@ -94,10 +94,10 @@ type Balancer interface {
// instead of blocking. // instead of blocking.
// //
// The function returns put which is called once the rpc has completed or failed. // The function returns put which is called once the rpc has completed or failed.
// put can collect and report RPC stats to a remote load balancer. gRPC internals // put can collect and report RPC stats to a remote load balancer.
// will try to call this again if err is non-nil (unless err is ErrClientConnClosing).
// //
// TODO: Add other non-recoverable errors? // This function should only return the errors Balancer cannot recover by itself.
// gRPC internals will fail the RPC if an error is returned.
Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error) Get(ctx context.Context, opts BalancerGetOptions) (addr Address, put func(), err error)
// Notify returns a channel that is used by gRPC internals to watch the addresses // Notify returns a channel that is used by gRPC internals to watch the addresses
// gRPC needs to connect. The addresses might be from a name resolver or remote // gRPC needs to connect. The addresses might be from a name resolver or remote
@ -298,8 +298,20 @@ func (rr *roundRobin) Get(ctx context.Context, opts BalancerGetOptions) (addr Ad
} }
} }
} }
// There is no address available. Wait on rr.waitCh. // There is no address available.
// TODO(zhaoq): Handle the case when opts.BlockingWait is false. if !opts.BlockingWait {
if len(rr.addrs) == 0 {
rr.mu.Unlock()
err = fmt.Errorf("there is no address available")
return
}
// Returns the next addr on rr.addrs for failfast RPCs.
addr = rr.addrs[rr.next].addr
rr.next++
rr.mu.Unlock()
return
}
// Wait on rr.waitCh for non-failfast RPCs.
if rr.waitCh == nil { if rr.waitCh == nil {
ch = make(chan struct{}) ch = make(chan struct{})
rr.waitCh = ch rr.waitCh = ch

View File

@ -239,7 +239,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
t.Fatalf("Failed to create ClientConn: %v", err) t.Fatalf("Failed to create ClientConn: %v", err)
} }
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port) t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want %s", err, servers[0].port)
} }
// Remove the server. // Remove the server.
@ -251,7 +251,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
// Loop until the above update applies. // Loop until the above update applies.
for { for {
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -262,7 +262,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -270,7 +270,7 @@ func TestCloseWithPendingRPC(t *testing.T) {
defer wg.Done() defer wg.Done()
var reply string var reply string
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err == nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err == nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want not nil", err)
} }
}() }()
@ -295,7 +295,7 @@ func TestGetOnWaitChannel(t *testing.T) {
for { for {
var reply string var reply string
ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond)
if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc); Code(err) == codes.DeadlineExceeded { if err := Invoke(ctx, "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); Code(err) == codes.DeadlineExceeded {
break break
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
@ -305,7 +305,7 @@ func TestGetOnWaitChannel(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
var reply string var reply string
if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc); err != nil { if err := Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply, cc, FailFast(false)); err != nil {
t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) t.Errorf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err)
} }
}() }()

10
call.go
View File

@ -35,6 +35,7 @@ package grpc
import ( import (
"bytes" "bytes"
//"fmt"
"io" "io"
"time" "time"
@ -101,7 +102,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd
// Invoke is called by generated code. Also users can call Invoke directly when it // Invoke is called by generated code. Also users can call Invoke directly when it
// is really needed in their use cases. // is really needed in their use cases.
func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) {
var c callInfo c := defaultCallInfo
for _, o := range opts { for _, o := range opts {
if err := o.before(&c); err != nil { if err := o.before(&c); err != nil {
return toRPCErr(err) return toRPCErr(err)
@ -165,10 +166,13 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.failFast { if c.failFast {
return toRPCErr(err) return toRPCErr(err)
} }
}
// All the remaining cases are treated as retryable.
continue continue
} }
// ALl the other errors are treated as Internal errors.
return Errorf(codes.Internal, "%v", err)
// All the remaining cases are treated as fatal.
//panic(fmt.Sprintf("ClientConn.getTransport got an unsupported error: %v", err))
}
if c.traceInfo.tr != nil { if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
} }

View File

@ -424,7 +424,6 @@ func (cc *ClientConn) newAddrConn(addr Address, skipWait bool) error {
} }
func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) { func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions) (transport.ClientTransport, func(), error) {
// TODO(zhaoq): Implement fail-fast logic.
addr, put, err := cc.balancer.Get(ctx, opts) addr, put, err := cc.balancer.Get(ctx, opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -442,7 +441,7 @@ func (cc *ClientConn) getTransport(ctx context.Context, opts BalancerGetOptions)
} }
return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc") return nil, nil, transport.StreamErrorf(codes.Internal, "grpc: failed to find the transport to send the rpc")
} }
t, err := ac.wait(ctx) t, err := ac.wait(ctx, !opts.BlockingWait)
if err != nil { if err != nil {
if put != nil { if put != nil {
put() put()
@ -649,8 +648,9 @@ func (ac *addrConn) transportMonitor() {
} }
} }
// wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed. // wait blocks until i) the new transport is up or ii) ctx is done or iii) ac is closed or
func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error) { // iv) transport is in TransientFailure and the RPC is fail-fast.
func (ac *addrConn) wait(ctx context.Context, failFast bool) (transport.ClientTransport, error) {
for { for {
ac.mu.Lock() ac.mu.Lock()
switch { switch {
@ -662,6 +662,10 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
ac.mu.Unlock() ac.mu.Unlock()
return ct, nil return ct, nil
default: default:
if ac.state == TransientFailure && failFast {
ac.mu.Unlock()
return nil, transport.StreamErrorf(codes.Canceled, "grpc: RPC failed fast due to transport failure")
}
ready := ac.ready ready := ac.ready
if ready == nil { if ready == nil {
ready = make(chan struct{}) ready = make(chan struct{})
@ -673,6 +677,13 @@ func (ac *addrConn) wait(ctx context.Context) (transport.ClientTransport, error)
return nil, transport.ContextErr(ctx.Err()) return nil, transport.ContextErr(ctx.Err())
// Wait until the new transport is ready or failed. // Wait until the new transport is ready or failed.
case <-ready: case <-ready:
ac.mu.Lock()
if ac.state == TransientFailure && failFast {
ac.mu.Unlock()
return nil, transport.StreamErrorf(codes.Canceled, "grpc: RPC failed fast due to transport failure")
}
ac.mu.Unlock()
} }
} }
} }

View File

@ -141,6 +141,8 @@ type callInfo struct {
traceInfo traceInfo // in trace.go traceInfo traceInfo // in trace.go
} }
var defaultCallInfo = callInfo{failFast: true}
// CallOption configures a Call before it starts or extracts information from // CallOption configures a Call before it starts or extracts information from
// a Call after it completes. // a Call after it completes.
type CallOption interface { type CallOption interface {
@ -179,6 +181,18 @@ func Trailer(md *metadata.MD) CallOption {
}) })
} }
// FailFast configures the action to take when an RPC is attempted on broken
// connections or unreachable servers. If failfast is true, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will retry
// the call if it fails due to a transient error.
func FailFast(failFast bool) CallOption {
return beforeCall(func(c *callInfo) error {
c.failFast = failFast
return nil
})
}
// The format of the payload: compressed or not? // The format of the payload: compressed or not?
type payloadFormat uint8 type payloadFormat uint8

View File

@ -105,9 +105,14 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
err error err error
put func() put func()
) )
// TODO(zhaoq): CallOption is omitted. Add support when it is needed. c := defaultCallInfo
for _, o := range opts {
if err := o.before(&c); err != nil {
return nil, toRPCErr(err)
}
}
gopts := BalancerGetOptions{ gopts := BalancerGetOptions{
BlockingWait: false, BlockingWait: !c.failFast,
} }
t, put, err = cc.getTransport(ctx, gopts) t, put, err = cc.getTransport(ctx, gopts)
if err != nil { if err != nil {
@ -122,6 +127,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
callHdr.SendCompress = cc.dopts.cp.Type() callHdr.SendCompress = cc.dopts.cp.Type()
} }
cs := &clientStream{ cs := &clientStream{
opts: opts,
c: c,
desc: desc, desc: desc,
put: put, put: put,
codec: cc.dopts.codec, codec: cc.dopts.codec,
@ -167,6 +174,8 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
// clientStream implements a client side Stream. // clientStream implements a client side Stream.
type clientStream struct { type clientStream struct {
opts []CallOption
c callInfo
t transport.ClientTransport t transport.ClientTransport
s *transport.Stream s *transport.Stream
p *parser p *parser
@ -312,15 +321,18 @@ func (cs *clientStream) closeTransportStream(err error) {
} }
func (cs *clientStream) finish(err error) { func (cs *clientStream) finish(err error) {
if !cs.tracing {
return
}
cs.mu.Lock() cs.mu.Lock()
defer cs.mu.Unlock() defer cs.mu.Unlock()
for _, o := range cs.opts {
o.after(&cs.c)
}
if cs.put != nil { if cs.put != nil {
cs.put() cs.put()
cs.put = nil cs.put = nil
} }
if !cs.tracing {
return
}
if cs.trInfo.tr != nil { if cs.trInfo.tr != nil {
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
cs.trInfo.tr.LazyPrintf("RPC: [OK]") cs.trInfo.tr.LazyPrintf("RPC: [OK]")

View File

@ -550,7 +550,7 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
cc := te.clientConn() cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc) tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil { if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err) t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
} }
te.srv.Stop() te.srv.Stop()
@ -558,12 +558,54 @@ func testTimeoutOnDeadServer(t *testing.T, e env) {
// notification in time the failure path of the 1st invoke of // notification in time the failure path of the 1st invoke of
// ClientConn.wait hits the deadline exceeded error. // ClientConn.wait hits the deadline exceeded error.
ctx, _ := context.WithTimeout(context.Background(), -1) ctx, _ := context.WithTimeout(context.Background(), -1)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
} }
awaitNewConnLogOutput() awaitNewConnLogOutput()
} }
func TestFailFast(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testFailFast(t, e)
}
}
func testFailFast(t *testing.T, e env) {
te := newTest(t, e)
te.userAgent = testAppUA
te.declareLogNoise(
"transport: http2Client.notifyError got notified that the client transport was broken EOF",
"grpc: Conn.transportMonitor exits due to: grpc: the client connection is closing",
"grpc: Conn.resetTransport failed to create client transport: connection error",
"grpc: Conn.resetTransport failed to create client transport: connection error: desc = \"transport: dial unix",
)
te.startServer()
defer te.tearDown()
cc := te.clientConn()
tc := testpb.NewTestServiceClient(cc)
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, <nil>", err)
}
// Stop the server and tear down all the exisiting connections.
te.srv.Stop()
// Issue an RPC to make sure the server teardown is propagated to the client already.
ctx, _ := context.WithTimeout(context.Background(), time.Millisecond)
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
t.Fatalf("TestService/EmptyCall(%v, _) = _, %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded)
}
// The client keeps reconnecting and ongoing fail-fast RPCs should fail with code.Canceled.
if _, err := tc.EmptyCall(context.Background(), &testpb.Empty{}); grpc.Code(err) != codes.Canceled {
t.Fatalf("TestService/EmptyCall(_, _, _) = _, %v, want _, error code: %d", err, codes.Canceled)
}
if _, err := tc.StreamingInputCall(ctx); grpc.Code(err) != codes.Canceled {
t.Fatalf("TestService/StreamingInputCall(_) = _, %v, want _, error code: %d", err, codes.Canceled)
}
awaitNewConnLogOutput()
}
func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { func healthCheck(d time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) {
ctx, _ := context.WithTimeout(context.Background(), d) ctx, _ := context.WithTimeout(context.Background(), d)
hc := healthpb.NewHealthClient(cc) hc := healthpb.NewHealthClient(cc)
@ -879,7 +921,7 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup
ResponseSize: proto.Int32(respSize), ResponseSize: proto.Int32(respSize),
Payload: payload, Payload: payload,
} }
reply, err := tc.UnaryCall(context.Background(), req) reply, err := tc.UnaryCall(context.Background(), req, grpc.FailFast(false))
if err != nil { if err != nil {
t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err) t.Errorf("TestService/UnaryCall(_, _) = _, %v, want _, <nil>", err)
return return