diff --git a/call.go b/call.go index 858bf542e..3ef343a7b 100644 --- a/call.go +++ b/call.go @@ -197,8 +197,9 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli for { // TODO(zhaoq): Need a formal spec of fail-fast. callHdr := &transport.CallHdr{ - Host: cc.authority, - Method: method, + Host: cc.authority, + Method: method, + FailFast: c.failFast, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() diff --git a/stats/stats.go b/stats/stats.go index 97ee51cc3..439019dcd 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -147,6 +147,8 @@ type OutHeader struct { LocalAddr net.Addr // Encryption is encrypt method used in the RPC. Encryption string + // Failfast indicates if this RPC is failfast. + FailFast bool } func (s *OutHeader) isStats() {} diff --git a/stats/stats_test.go b/stats/stats_test.go index 15aafc5df..47b336f3a 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -221,21 +221,27 @@ func (te *test) clientConn() *grpc.ClientConn { return te.cc } -func (te *test) doUnaryCall(success bool) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { +type rpcConfig struct { + count int // Number of requests and responses for streaming RPCs. + success bool // Whether the RPC should succeed or return error. + failfast bool +} + +func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { var ( resp *testpb.SimpleResponse req *testpb.SimpleRequest err error ) tc := testpb.NewTestServiceClient(te.clientConn()) - if success { + if c.success { req = &testpb.SimpleRequest{Id: 1} } else { req = &testpb.SimpleRequest{Id: errorID} } ctx := metadata.NewContext(context.Background(), testMetadata) - resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(false)) + resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast)) if err != nil { return req, resp, err } @@ -243,22 +249,22 @@ func (te *test) doUnaryCall(success bool) (*testpb.SimpleRequest, *testpb.Simple return req, resp, err } -func (te *test) doFullDuplexCallRoundtrip(count int, success bool) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { +func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { var ( reqs []*testpb.SimpleRequest resps []*testpb.SimpleResponse err error ) tc := testpb.NewTestServiceClient(te.clientConn()) - stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata)) + stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) if err != nil { return reqs, resps, err } var startID int32 - if !success { + if !c.success { startID = errorID } - for i := 0; i < count; i++ { + for i := 0; i < c.count; i++ { req := &testpb.SimpleRequest{ Id: int32(i) + startID, } @@ -291,6 +297,7 @@ type expectedData struct { respIdx int responses []*testpb.SimpleResponse err error + failfast bool } type gotData struct { @@ -428,6 +435,9 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { if st.Encryption != e.encryption { t.Fatalf("st.Encryption = %v, want %v", st.Encryption, e.encryption) } + if st.FailFast != e.failfast { + t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) + } } } @@ -534,7 +544,7 @@ func TestServerStatsUnaryRPC(t *testing.T) { te.startServer(&testServer{}) defer te.tearDown() - req, resp, err := te.doUnaryCall(true) + req, resp, err := te.doUnaryCall(&rpcConfig{success: true}) if err != nil { t.Fatalf(err.Error()) } @@ -585,7 +595,7 @@ func TestServerStatsUnaryRPCError(t *testing.T) { te.startServer(&testServer{}) defer te.tearDown() - req, resp, err := te.doUnaryCall(false) + req, resp, err := te.doUnaryCall(&rpcConfig{success: false}) if err == nil { t.Fatalf("got error ; want ") } @@ -638,7 +648,7 @@ func TestServerStatsStreamingRPC(t *testing.T) { defer te.tearDown() count := 5 - reqs, resps, err := te.doFullDuplexCallRoundtrip(count, true) + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true}) if err == nil { t.Fatalf(err.Error()) } @@ -696,7 +706,7 @@ func TestServerStatsStreamingRPCError(t *testing.T) { defer te.tearDown() count := 5 - reqs, resps, err := te.doFullDuplexCallRoundtrip(count, false) + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false}) if err == nil { t.Fatalf("got error ; want ") } @@ -754,7 +764,8 @@ func TestClientStatsUnaryRPC(t *testing.T) { te.startServer(&testServer{}) defer te.tearDown() - req, resp, err := te.doUnaryCall(true) + failfast := false + req, resp, err := te.doUnaryCall(&rpcConfig{success: true, failfast: failfast}) if err != nil { t.Fatalf(err.Error()) } @@ -765,6 +776,7 @@ func TestClientStatsUnaryRPC(t *testing.T) { serverAddr: te.srvAddr, requests: []*testpb.SimpleRequest{req}, responses: []*testpb.SimpleResponse{resp}, + failfast: failfast, } checkFuncs := map[int]*checkFuncWithCount{ @@ -842,7 +854,8 @@ func TestClientStatsUnaryRPCError(t *testing.T) { te.startServer(&testServer{}) defer te.tearDown() - req, resp, err := te.doUnaryCall(false) + failfast := true + req, resp, err := te.doUnaryCall(&rpcConfig{success: false, failfast: failfast}) if err == nil { t.Fatalf("got error ; want ") } @@ -854,6 +867,7 @@ func TestClientStatsUnaryRPCError(t *testing.T) { requests: []*testpb.SimpleRequest{req}, responses: []*testpb.SimpleResponse{resp}, err: err, + failfast: failfast, } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ @@ -895,7 +909,8 @@ func TestClientStatsStreamingRPC(t *testing.T) { defer te.tearDown() count := 5 - reqs, resps, err := te.doFullDuplexCallRoundtrip(count, true) + failfast := false + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: true, failfast: failfast}) if err == nil { t.Fatalf(err.Error()) } @@ -907,6 +922,7 @@ func TestClientStatsStreamingRPC(t *testing.T) { encryption: "gzip", requests: reqs, responses: resps, + failfast: failfast, } checkFuncs := map[int]*checkFuncWithCount{ @@ -985,7 +1001,8 @@ func TestClientStatsStreamingRPCError(t *testing.T) { defer te.tearDown() count := 5 - reqs, resps, err := te.doFullDuplexCallRoundtrip(count, false) + failfast := true + reqs, resps, err := te.doFullDuplexCallRoundtrip(&rpcConfig{count: count, success: false, failfast: failfast}) if err == nil { t.Fatalf("got error ; want ") } @@ -998,6 +1015,7 @@ func TestClientStatsStreamingRPCError(t *testing.T) { requests: reqs, responses: resps, err: err, + failfast: failfast, } checkFuncs := map[int]*checkFuncWithCount{ diff --git a/stream.go b/stream.go index a1d03c4f1..9d9226a47 100644 --- a/stream.go +++ b/stream.go @@ -127,9 +127,10 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } } callHdr := &transport.CallHdr{ - Host: cc.authority, - Method: method, - Flush: desc.ServerStreams && desc.ClientStreams, + Host: cc.authority, + Method: method, + Flush: desc.ServerStreams && desc.ClientStreams, + FailFast: c.failFast, } if cc.dopts.cp != nil { callHdr.SendCompress = cc.dopts.cp.Type() diff --git a/transport/http2_client.go b/transport/http2_client.go index d7e588657..26f6b509b 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -457,6 +457,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea RemoteAddr: t.RemoteAddr(), LocalAddr: t.LocalAddr(), Encryption: callHdr.SendCompress, + FailFast: callHdr.FailFast, } stats.Handle(s.Context(), outHeader) } diff --git a/transport/transport.go b/transport/transport.go index a78249656..44d774889 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -427,6 +427,9 @@ type CallHdr struct { // only a hint. The transport may modify the flush decision // for performance purposes. Flush bool + + // FailFast indicates whether the RPC is failfast. + FailFast bool } // ClientTransport is the common interface for all gRPC client-side transport