From cd8432ec079f38bea795e1e19a1f48a04d0dccb4 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 12 Apr 2017 11:55:54 -0700 Subject: [PATCH] Move handling stats.End to clientStream.finish() (#1182) * move handling stats.End to clientStream.finish() * add stats test for streaming RPC not calling last recv() --- stats/stats_test.go | 33 +++++++++++++++++++++++++++------ stream.go | 41 +++++++++++++++++++++-------------------- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/stats/stats_test.go b/stats/stats_test.go index 6121b432b..c770c151b 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -219,10 +219,11 @@ func (te *test) clientConn() *grpc.ClientConn { } type rpcConfig struct { - count int // Number of requests and responses for streaming RPCs. - success bool // Whether the RPC should succeed or return error. - failfast bool - streaming bool // Whether the rpc should be a streaming RPC. + count int // Number of requests and responses for streaming RPCs. + success bool // Whether the RPC should succeed or return error. + failfast bool + streaming bool // Whether the rpc should be a streaming RPC. + noLastRecv bool // Whether to call recv for io.EOF. When true, last recv won't be called. Only valid for streaming RPCs. } func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { @@ -275,8 +276,14 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest if err = stream.CloseSend(); err != nil && err != io.EOF { return reqs, resps, err } - if _, err = stream.Recv(); err != io.EOF { - return reqs, resps, err + if !c.noLastRecv { + if _, err = stream.Recv(); err != io.EOF { + return reqs, resps, err + } + } else { + // In the case of not calling the last recv, sleep to avoid + // returning too fast to miss the remaining stats (InTrailer and End). + time.Sleep(time.Second) } return reqs, resps, nil @@ -968,6 +975,20 @@ func TestClientStatsStreamingRPC(t *testing.T) { }) } +// If the user doesn't call the last recv() on clientSteam. +func TestClientStatsStreamingRPCNotCallingLastRecv(t *testing.T) { + count := 1 + testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, streaming: true, noLastRecv: true}, map[int]*checkFuncWithCount{ + begin: {checkBegin, 1}, + outHeader: {checkOutHeader, 1}, + outPayload: {checkOutPayload, count}, + inHeader: {checkInHeader, 1}, + inPayload: {checkInPayload, count}, + inTrailer: {checkInTrailer, 1}, + end: {checkEnd, 1}, + }) +} + func TestClientStatsStreamingRPCError(t *testing.T) { count := 5 testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, streaming: true}, map[int]*checkFuncWithCount{ diff --git a/stream.go b/stream.go index ecb1a31f6..008ff10eb 100644 --- a/stream.go +++ b/stream.go @@ -271,9 +271,10 @@ type clientStream struct { tracing bool // set to EnableTracing when the clientStream is created. - mu sync.Mutex - put func() - closed bool + mu sync.Mutex + put func() + closed bool + finished bool // trInfo.tr is set when the clientStream is created (if EnableTracing is true), // and is set to nil when the clientStream's finish method is called. trInfo traceInfo @@ -359,21 +360,6 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { - defer func() { - if err != nil && cs.statsHandler != nil { - // Only generate End if err != nil. - // If err == nil, it's not the last RecvMsg. - // The last RecvMsg gets either an RPC error or io.EOF. - end := &stats.End{ - Client: true, - EndTime: time.Now(), - } - if err != io.EOF { - end.Error = toRPCErr(err) - } - cs.statsHandler.HandleRPC(cs.statsCtx, end) - } - }() var inPayload *stats.InPayload if cs.statsHandler != nil { inPayload = &stats.InPayload{ @@ -459,13 +445,17 @@ func (cs *clientStream) closeTransportStream(err error) { } func (cs *clientStream) finish(err error) { + cs.mu.Lock() + defer cs.mu.Unlock() + if cs.finished { + return + } + cs.finished = true defer func() { if cs.cancel != nil { cs.cancel() } }() - cs.mu.Lock() - defer cs.mu.Unlock() for _, o := range cs.opts { o.after(&cs.c) } @@ -473,6 +463,17 @@ func (cs *clientStream) finish(err error) { cs.put() cs.put = nil } + if cs.statsHandler != nil { + end := &stats.End{ + Client: true, + EndTime: time.Now(), + } + if err != io.EOF { + // end.Error is nil if the RPC finished successfully. + end.Error = toRPCErr(err) + } + cs.statsHandler.HandleRPC(cs.statsCtx, end) + } if !cs.tracing { return }