From deb01f422a11e3ca17cfe1ec8a3596debda9c473 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Wed, 30 Nov 2016 16:25:46 -0800 Subject: [PATCH] add stats tagger APIs and connection stats. (#992) * add stats.tagger APIs and connection stats. * fix comments use ac.ctx in http2client change name and comments small fixes stats_tests * add a TODO to ConnTagInfo * rename handle to handleRPC * modify stats comments --- call.go | 9 +- server.go | 12 +- stats/handlers.go | 152 ++++++++++++++ stats/stats.go | 70 ++++--- stats/stats_test.go | 421 +++++++++++++++++++++++++++++--------- stream.go | 15 +- transport/http2_client.go | 24 ++- transport/http2_server.go | 25 ++- 8 files changed, 576 insertions(+), 152 deletions(-) create mode 100644 stats/handlers.go diff --git a/call.go b/call.go index 5d9214d15..fc8e18afc 100644 --- a/call.go +++ b/call.go @@ -82,7 +82,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran if inPayload != nil && err == io.EOF && stream.StatusCode() == codes.OK { // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. // Fix the order if necessary. - stats.Handle(ctx, inPayload) + stats.HandleRPC(ctx, inPayload) } c.trailerMD = stream.Trailer() return nil @@ -121,7 +121,7 @@ func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHd err = t.Write(stream, outBuf, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.Handle(ctx, outPayload) + stats.HandleRPC(ctx, outPayload) } // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following @@ -172,12 +172,13 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli }() } if stats.On() { + ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } - stats.Handle(ctx, begin) + stats.HandleRPC(ctx, begin) } defer func() { if stats.On() { @@ -186,7 +187,7 @@ func invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli EndTime: time.Now(), Error: e, } - stats.Handle(ctx, end) + stats.HandleRPC(ctx, end) } }() topts := &transport.Options{ diff --git a/server.go b/server.go index 3af001ac9..22aa33bfe 100644 --- a/server.go +++ b/server.go @@ -583,7 +583,7 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str err = t.Write(stream, p, opts) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.Handle(stream.Context(), outPayload) + stats.HandleRPC(stream.Context(), outPayload) } return err } @@ -593,7 +593,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. begin := &stats.Begin{ BeginTime: time.Now(), } - stats.Handle(stream.Context(), begin) + stats.HandleRPC(stream.Context(), begin) } defer func() { if stats.On() { @@ -603,7 +603,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - stats.Handle(stream.Context(), end) + stats.HandleRPC(stream.Context(), end) } }() if trInfo != nil { @@ -698,7 +698,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. inPayload.Payload = v inPayload.Data = req inPayload.Length = len(req) - stats.Handle(stream.Context(), inPayload) + stats.HandleRPC(stream.Context(), inPayload) } if trInfo != nil { trInfo.tr.LazyLog(&payload{sent: false, msg: v}, true) @@ -759,7 +759,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp begin := &stats.Begin{ BeginTime: time.Now(), } - stats.Handle(stream.Context(), begin) + stats.HandleRPC(stream.Context(), begin) } defer func() { if stats.On() { @@ -769,7 +769,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - stats.Handle(stream.Context(), end) + stats.HandleRPC(stream.Context(), end) } }() if s.opts.cp != nil { diff --git a/stats/handlers.go b/stats/handlers.go new file mode 100644 index 000000000..d41c52442 --- /dev/null +++ b/stats/handlers.go @@ -0,0 +1,152 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package stats + +import ( + "net" + "sync/atomic" + + "golang.org/x/net/context" + "google.golang.org/grpc/grpclog" +) + +// ConnTagInfo defines the relevant information needed by connection context tagger. +type ConnTagInfo struct { + // RemoteAddr is the remote address of the corresponding connection. + RemoteAddr net.Addr + // LocalAddr is the local address of the corresponding connection. + LocalAddr net.Addr + // TODO add QOS related fields. +} + +// RPCTagInfo defines the relevant information needed by RPC context tagger. +type RPCTagInfo struct { + // FullMethodName is the RPC method in the format of /package.service/method. + FullMethodName string +} + +var ( + on = new(int32) + rpcHandler func(context.Context, RPCStats) + connHandler func(context.Context, ConnStats) + connTagger func(context.Context, *ConnTagInfo) context.Context + rpcTagger func(context.Context, *RPCTagInfo) context.Context +) + +// HandleRPC processes the RPC stats using the rpc handler registered by the user. +func HandleRPC(ctx context.Context, s RPCStats) { + if rpcHandler == nil { + return + } + rpcHandler(ctx, s) +} + +// RegisterRPCHandler registers the user handler function for RPC stats processing. +// It should be called only once. The later call will overwrite the former value if it is called multiple times. +// This handler function will be called to process the rpc stats. +func RegisterRPCHandler(f func(context.Context, RPCStats)) { + rpcHandler = f +} + +// HandleConn processes the stats using the call back function registered by user. +func HandleConn(ctx context.Context, s ConnStats) { + if connHandler == nil { + return + } + connHandler(ctx, s) +} + +// RegisterConnHandler registers the user handler function for conn stats. +// It should be called only once. The later call will overwrite the former value if it is called multiple times. +// This handler function will be called to process the conn stats. +func RegisterConnHandler(f func(context.Context, ConnStats)) { + connHandler = f +} + +// TagConn calls user registered connection context tagger. +func TagConn(ctx context.Context, info *ConnTagInfo) context.Context { + if connTagger == nil { + return ctx + } + return connTagger(ctx, info) +} + +// RegisterConnTagger registers the user connection context tagger function. +// The connection context tagger can attach some information to the given context. +// The returned context will be used for stats handling. +// For conn stats handling, the context used in connHandler for this +// connection will be derived from the context returned. +// For RPC stats handling, +// - On server side, the context used in rpcHandler for all RPCs on this +// connection will be derived from the context returned. +// - On client side, the context is not derived from the context returned. +func RegisterConnTagger(t func(context.Context, *ConnTagInfo) context.Context) { + connTagger = t +} + +// TagRPC calls the user registered RPC context tagger. +func TagRPC(ctx context.Context, info *RPCTagInfo) context.Context { + if rpcTagger == nil { + return ctx + } + return rpcTagger(ctx, info) +} + +// RegisterRPCTagger registers the user RPC context tagger function. +// The RPC context tagger can attach some information to the given context. +// The context used in stats rpcHandler for this RPC will be derived from the +// context returned. +func RegisterRPCTagger(t func(context.Context, *RPCTagInfo) context.Context) { + rpcTagger = t +} + +// Start starts the stats collection and processing if there is a registered stats handle. +func Start() { + if rpcHandler == nil && connHandler == nil { + grpclog.Println("rpcHandler and connHandler are both nil when starting stats. Stats is not started") + return + } + atomic.StoreInt32(on, 1) +} + +// Stop stops the stats collection and processing. +// Stop does not unregister the handlers. +func Stop() { + atomic.StoreInt32(on, 0) +} + +// On indicates whether the stats collection and processing is on. +func On() bool { + return atomic.CompareAndSwapInt32(on, 1, 1) +} diff --git a/stats/stats.go b/stats/stats.go index 4b030d985..a82448a68 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -38,16 +38,12 @@ package stats // import "google.golang.org/grpc/stats" import ( "net" - "sync/atomic" "time" - - "golang.org/x/net/context" - "google.golang.org/grpc/grpclog" ) // RPCStats contains stats information about RPCs. -// All stats types in this package implements this interface. type RPCStats interface { + isRPCStats() // IsClient returns true if this RPCStats is from client side. IsClient() bool } @@ -66,6 +62,8 @@ type Begin struct { // IsClient indicates if this is from client side. func (s *Begin) IsClient() bool { return s.Client } +func (s *Begin) isRPCStats() {} + // InPayload contains the information for an incoming payload. type InPayload struct { // Client is true if this InPayload is from client side. @@ -85,6 +83,8 @@ type InPayload struct { // IsClient indicates if this is from client side. func (s *InPayload) IsClient() bool { return s.Client } +func (s *InPayload) isRPCStats() {} + // InHeader contains stats when a header is received. // FullMethod, addresses and Compression are only valid if Client is false. type InHeader struct { @@ -106,6 +106,8 @@ type InHeader struct { // IsClient indicates if this is from client side. func (s *InHeader) IsClient() bool { return s.Client } +func (s *InHeader) isRPCStats() {} + // InTrailer contains stats when a trailer is received. type InTrailer struct { // Client is true if this InTrailer is from client side. @@ -117,6 +119,8 @@ type InTrailer struct { // IsClient indicates if this is from client side. func (s *InTrailer) IsClient() bool { return s.Client } +func (s *InTrailer) isRPCStats() {} + // OutPayload contains the information for an outgoing payload. type OutPayload struct { // Client is true if this OutPayload is from client side. @@ -136,6 +140,8 @@ type OutPayload struct { // IsClient indicates if this is from client side. func (s *OutPayload) IsClient() bool { return s.Client } +func (s *OutPayload) isRPCStats() {} + // OutHeader contains stats when a header is sent. // FullMethod, addresses and Compression are only valid if Client is true. type OutHeader struct { @@ -157,6 +163,8 @@ type OutHeader struct { // IsClient indicates if this is from client side. func (s *OutHeader) IsClient() bool { return s.Client } +func (s *OutHeader) isRPCStats() {} + // OutTrailer contains stats when a trailer is sent. type OutTrailer struct { // Client is true if this OutTrailer is from client side. @@ -168,6 +176,8 @@ type OutTrailer struct { // IsClient indicates if this is from client side. func (s *OutTrailer) IsClient() bool { return s.Client } +func (s *OutTrailer) isRPCStats() {} + // End contains stats when an RPC ends. type End struct { // Client is true if this End is from client side. @@ -181,39 +191,33 @@ type End struct { // IsClient indicates if this is from client side. func (s *End) IsClient() bool { return s.Client } -var ( - on = new(int32) - handler func(context.Context, RPCStats) -) +func (s *End) isRPCStats() {} -// On indicates whether stats is started. -func On() bool { - return atomic.CompareAndSwapInt32(on, 1, 1) +// ConnStats contains stats information about connections. +type ConnStats interface { + isConnStats() + // IsClient returns true if this ConnStats is from client side. + IsClient() bool } -// Handle processes the stats using the call back function registered by user. -func Handle(ctx context.Context, s RPCStats) { - handler(ctx, s) +// ConnBegin contains the stats of a connection when it is established. +type ConnBegin struct { + // Client is true if this ConnBegin is from client side. + Client bool } -// RegisterHandler registers the user handler function. -// If another handler was registered before, this new handler will overwrite the old one. -// This handler function will be called to process the stats. -func RegisterHandler(f func(context.Context, RPCStats)) { - handler = f +// IsClient indicates if this is from client side. +func (s *ConnBegin) IsClient() bool { return s.Client } + +func (s *ConnBegin) isConnStats() {} + +// ConnEnd contains the stats of a connection when it ends. +type ConnEnd struct { + // Client is true if this ConnEnd is from client side. + Client bool } -// Start starts the stats collection and reporting if there is a registered stats handle. -func Start() { - if handler == nil { - grpclog.Println("handler is nil when starting stats. Stats is not started") - return - } - atomic.StoreInt32(on, 1) -} +// IsClient indicates if this is from client side. +func (s *ConnEnd) IsClient() bool { return s.Client } -// Stop stops the stats collection and processing. -// Stop does not unregister handler. -func Stop() { - atomic.StoreInt32(on, 0) -} +func (s *ConnEnd) isConnStats() {} diff --git a/stats/stats_test.go b/stats/stats_test.go index e904810bd..1761e7944 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -49,26 +49,87 @@ import ( testpb "google.golang.org/grpc/stats/grpc_testing" ) +func init() { + grpc.EnableTracing = false +} + func TestStartStop(t *testing.T) { - stats.RegisterHandler(nil) + stats.RegisterRPCHandler(nil) + stats.RegisterConnHandler(nil) stats.Start() - if stats.On() != false { + if stats.On() { t.Fatalf("stats.Start() with nil handler, stats.On() = true, want false") } - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) {}) - if stats.On() != false { - t.Fatalf("after stats.RegisterHandler(), stats.On() = true, want false") - } + + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {}) + stats.RegisterConnHandler(nil) stats.Start() - if stats.On() != true { - t.Fatalf("after stats.Start(_), stats.On() = false, want true") + if !stats.On() { + t.Fatalf("stats.Start() with non-nil handler, stats.On() = false, want true") } stats.Stop() - if stats.On() != false { + + stats.RegisterRPCHandler(nil) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {}) + stats.Start() + if !stats.On() { + t.Fatalf("stats.Start() with non-nil conn handler, stats.On() = false, want true") + } + stats.Stop() + + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) {}) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) {}) + if stats.On() { + t.Fatalf("after stats.RegisterRPCHandler(), stats.On() = true, want false") + } + stats.Start() + if !stats.On() { + t.Fatalf("after stats.Start(_), stats.On() = false, want true") + } + + stats.Stop() + if stats.On() { t.Fatalf("after stats.Stop(), stats.On() = true, want false") } } +type connCtxKey struct{} +type rpcCtxKey struct{} + +func TestTagConnCtx(t *testing.T) { + defer stats.RegisterConnTagger(nil) + ctx1 := context.Background() + stats.RegisterConnTagger(nil) + ctx2 := stats.TagConn(ctx1, nil) + if ctx2 != ctx1 { + t.Fatalf("nil conn ctx tagger should not modify context, got %v; want %v", ctx2, ctx1) + } + stats.RegisterConnTagger(func(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return context.WithValue(ctx, connCtxKey{}, "connctxvalue") + }) + ctx3 := stats.TagConn(ctx1, nil) + if v, ok := ctx3.Value(connCtxKey{}).(string); !ok || v != "connctxvalue" { + t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, connCtxKey{}, "connctxvalue")) + } +} + +func TestTagRPCCtx(t *testing.T) { + defer stats.RegisterRPCTagger(nil) + ctx1 := context.Background() + stats.RegisterRPCTagger(nil) + ctx2 := stats.TagRPC(ctx1, nil) + if ctx2 != ctx1 { + t.Fatalf("nil rpc ctx tagger should not modify context, got %v; want %v", ctx2, ctx1) + } + stats.RegisterRPCTagger(func(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return context.WithValue(ctx, rpcCtxKey{}, "rpcctxvalue") + }) + ctx3 := stats.TagRPC(ctx1, nil) + if v, ok := ctx3.Value(rpcCtxKey{}).(string); !ok || v != "rpcctxvalue" { + t.Fatalf("got context %v; want %v", ctx3, context.WithValue(ctx1, rpcCtxKey{}, "rpcctxvalue")) + } +} + var ( // For headers: testMetadata = metadata.MD{ @@ -242,10 +303,6 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple ctx := metadata.NewContext(context.Background(), testMetadata) resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast)) - if err != nil { - return req, resp, err - } - return req, resp, err } @@ -303,7 +360,7 @@ type expectedData struct { type gotData struct { ctx context.Context client bool - s stats.RPCStats + s interface{} // This could be RPCStats or ConnStats. } const ( @@ -315,6 +372,8 @@ const ( outPayload outHeader outTrailer + connbegin + connend ) func checkBegin(t *testing.T, d *gotData, e *expectedData) { @@ -363,6 +422,24 @@ func checkInHeader(t *testing.T, d *gotData, e *expectedData) { if st.Compression != e.compression { t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) } + + if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok { + if connInfo.RemoteAddr != st.RemoteAddr { + t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr) + } + if connInfo.LocalAddr != st.LocalAddr { + t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr) + } + } else { + t.Fatalf("got context %v, want one with connCtxKey", d.ctx) + } + if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { + if rpcInfo.FullMethodName != st.FullMethod { + t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) + } + } else { + t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) + } } } @@ -451,11 +528,19 @@ func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) } if st.RemoteAddr.String() != e.serverAddr { - t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) + t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr) } if st.Compression != e.compression { t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) } + + if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { + if rpcInfo.FullMethodName != st.FullMethod { + t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) + } + } else { + t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) + } } } @@ -546,14 +631,91 @@ func checkEnd(t *testing.T, d *gotData, e *expectedData) { } } -func TestServerStatsUnaryRPC(t *testing.T) { - var got []*gotData +func checkConnBegin(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.ConnBegin + ) + if st, ok = d.s.(*stats.ConnBegin); !ok { + t.Fatalf("got %T, want ConnBegin", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want ") + } + st.IsClient() // TODO remove this. +} - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { +func checkConnEnd(t *testing.T, d *gotData, e *expectedData) { + var ( + ok bool + st *stats.ConnEnd + ) + if st, ok = d.s.(*stats.ConnEnd); !ok { + t.Fatalf("got %T, want ConnEnd", d.s) + } + if d.ctx == nil { + t.Fatalf("d.ctx = nil, want ") + } + st.IsClient() // TODO remove this. +} + +func tagConnCtx(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return context.WithValue(ctx, connCtxKey{}, info) +} + +func tagRPCCtx(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return context.WithValue(ctx, rpcCtxKey{}, info) +} + +func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { + if len(got) != len(checkFuncs) { + t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) + } + + var ( + rpcctx context.Context + connctx context.Context + ) + for i := 0; i < len(got); i++ { + if _, ok := got[i].s.(stats.RPCStats); ok { + if rpcctx != nil && got[i].ctx != rpcctx { + t.Fatalf("got different contexts with stats %T", got[i].s) + } + rpcctx = got[i].ctx + } else { + if connctx != nil && got[i].ctx != connctx { + t.Fatalf("got different contexts with stats %T", got[i].s) + } + connctx = got[i].ctx + } + } + + for i, f := range checkFuncs { + f(t, got[i], expect) + } +} + +func TestServerStatsUnaryRPC(t *testing.T) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() if !s.IsClient() { got = append(got, &gotData{ctx, false, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -575,6 +737,7 @@ func TestServerStatsUnaryRPC(t *testing.T) { } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkConnBegin, checkInHeader, checkBegin, checkInPayload, @@ -582,30 +745,33 @@ func TestServerStatsUnaryRPC(t *testing.T) { checkOutPayload, checkOutTrailer, checkEnd, + checkConnEnd, } - if len(got) != len(checkFuncs) { - t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) - } - - for i := 0; i < len(got)-1; i++ { - if got[i].ctx != got[i+1].ctx { - t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) - } - } - - for i, f := range checkFuncs { - f(t, got[i], expect) - } + checkServerStats(t, got, expect, checkFuncs) } func TestServerStatsUnaryRPCError(t *testing.T) { - var got []*gotData - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() if !s.IsClient() { got = append(got, &gotData{ctx, false, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -628,36 +794,40 @@ func TestServerStatsUnaryRPCError(t *testing.T) { } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkConnBegin, checkInHeader, checkBegin, checkInPayload, checkOutHeader, checkOutTrailer, checkEnd, + checkConnEnd, } - if len(got) != len(checkFuncs) { - t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) - } - - for i := 0; i < len(got)-1; i++ { - if got[i].ctx != got[i+1].ctx { - t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) - } - } - - for i, f := range checkFuncs { - f(t, got[i], expect) - } + checkServerStats(t, got, expect, checkFuncs) } func TestServerStatsStreamingRPC(t *testing.T) { - var got []*gotData - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() if !s.IsClient() { got = append(got, &gotData{ctx, false, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -681,6 +851,7 @@ func TestServerStatsStreamingRPC(t *testing.T) { } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkConnBegin, checkInHeader, checkBegin, checkOutHeader, @@ -692,31 +863,36 @@ func TestServerStatsStreamingRPC(t *testing.T) { for i := 0; i < count; i++ { checkFuncs = append(checkFuncs, ioPayFuncs...) } - checkFuncs = append(checkFuncs, checkOutTrailer, checkEnd) + checkFuncs = append(checkFuncs, + checkOutTrailer, + checkEnd, + checkConnEnd, + ) - if len(got) != len(checkFuncs) { - t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) - } - - for i := 0; i < len(got)-1; i++ { - if got[i].ctx != got[i+1].ctx { - t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) - } - } - - for i, f := range checkFuncs { - f(t, got[i], expect) - } + checkServerStats(t, got, expect, checkFuncs) } func TestServerStatsStreamingRPCError(t *testing.T) { - var got []*gotData - - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + var ( + mu sync.Mutex + got []*gotData + ) + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { + mu.Lock() + defer mu.Unlock() if !s.IsClient() { got = append(got, &gotData{ctx, false, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if !s.IsClient() { + got = append(got, &gotData{ctx, false, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -741,27 +917,17 @@ func TestServerStatsStreamingRPCError(t *testing.T) { } checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ + checkConnBegin, checkInHeader, checkBegin, checkOutHeader, checkInPayload, checkOutTrailer, checkEnd, + checkConnEnd, } - if len(got) != len(checkFuncs) { - t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) - } - - for i := 0; i < len(got)-1; i++ { - if got[i].ctx != got[i+1].ctx { - t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) - } - } - - for i, f := range checkFuncs { - f(t, got[i], expect) - } + checkServerStats(t, got, expect, checkFuncs) } type checkFuncWithCount struct { @@ -778,9 +944,21 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF t.Fatalf("got %v stats, want %v stats", len(got), expectLen) } - for i := 0; i < len(got)-1; i++ { - if got[i].ctx != got[i+1].ctx { - t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + var ( + rpcctx context.Context + connctx context.Context + ) + for i := 0; i < len(got); i++ { + if _, ok := got[i].s.(stats.RPCStats); ok { + if rpcctx != nil && got[i].ctx != rpcctx { + t.Fatalf("got different contexts with stats %T", got[i].s) + } + rpcctx = got[i].ctx + } else { + if connctx != nil && got[i].ctx != connctx { + t.Fatalf("got different contexts with stats %T", got[i].s) + } + connctx = got[i].ctx } } @@ -788,48 +966,60 @@ func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkF switch s.s.(type) { case *stats.Begin: if checkFuncs[begin].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[begin].f(t, s, expect) checkFuncs[begin].c-- case *stats.OutHeader: if checkFuncs[outHeader].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[outHeader].f(t, s, expect) checkFuncs[outHeader].c-- case *stats.OutPayload: if checkFuncs[outPayload].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[outPayload].f(t, s, expect) checkFuncs[outPayload].c-- case *stats.InHeader: if checkFuncs[inHeader].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[inHeader].f(t, s, expect) checkFuncs[inHeader].c-- case *stats.InPayload: if checkFuncs[inPayload].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[inPayload].f(t, s, expect) checkFuncs[inPayload].c-- case *stats.InTrailer: if checkFuncs[inTrailer].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[inTrailer].f(t, s, expect) checkFuncs[inTrailer].c-- case *stats.End: if checkFuncs[end].c <= 0 { - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } checkFuncs[end].f(t, s, expect) checkFuncs[end].c-- + case *stats.ConnBegin: + if checkFuncs[connbegin].c <= 0 { + t.Fatalf("unexpected stats: %T", s.s) + } + checkFuncs[connbegin].f(t, s, expect) + checkFuncs[connbegin].c-- + case *stats.ConnEnd: + if checkFuncs[connend].c <= 0 { + t.Fatalf("unexpected stats: %T", s.s) + } + checkFuncs[connend].f(t, s, expect) + checkFuncs[connend].c-- default: - t.Fatalf("unexpected stats: %T", s) + t.Fatalf("unexpected stats: %T", s.s) } } } @@ -839,13 +1029,22 @@ func TestClientStatsUnaryRPC(t *testing.T) { mu sync.Mutex got []*gotData ) - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { mu.Lock() defer mu.Unlock() if s.IsClient() { got = append(got, &gotData{ctx, true, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -869,6 +1068,7 @@ func TestClientStatsUnaryRPC(t *testing.T) { } checkFuncs := map[int]*checkFuncWithCount{ + connbegin: {checkConnBegin, 1}, begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, @@ -876,6 +1076,7 @@ func TestClientStatsUnaryRPC(t *testing.T) { inPayload: {checkInPayload, 1}, inTrailer: {checkInTrailer, 1}, end: {checkEnd, 1}, + connend: {checkConnEnd, 1}, } checkClientStats(t, got, expect, checkFuncs) @@ -886,13 +1087,22 @@ func TestClientStatsUnaryRPCError(t *testing.T) { mu sync.Mutex got []*gotData ) - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { mu.Lock() defer mu.Unlock() if s.IsClient() { got = append(got, &gotData{ctx, true, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -917,12 +1127,14 @@ func TestClientStatsUnaryRPCError(t *testing.T) { } checkFuncs := map[int]*checkFuncWithCount{ + connbegin: {checkConnBegin, 1}, begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, inHeader: {checkInHeader, 1}, inTrailer: {checkInTrailer, 1}, end: {checkEnd, 1}, + connend: {checkConnEnd, 1}, } checkClientStats(t, got, expect, checkFuncs) @@ -933,14 +1145,22 @@ func TestClientStatsStreamingRPC(t *testing.T) { mu sync.Mutex got []*gotData ) - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { mu.Lock() defer mu.Unlock() if s.IsClient() { - // t.Logf(" == %T %v", s, s.IsClient()) got = append(got, &gotData{ctx, true, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -966,6 +1186,7 @@ func TestClientStatsStreamingRPC(t *testing.T) { } checkFuncs := map[int]*checkFuncWithCount{ + connbegin: {checkConnBegin, 1}, begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, count}, @@ -973,6 +1194,7 @@ func TestClientStatsStreamingRPC(t *testing.T) { inPayload: {checkInPayload, count}, inTrailer: {checkInTrailer, 1}, end: {checkEnd, 1}, + connend: {checkConnEnd, 1}, } checkClientStats(t, got, expect, checkFuncs) @@ -983,13 +1205,22 @@ func TestClientStatsStreamingRPCError(t *testing.T) { mu sync.Mutex got []*gotData ) - stats.RegisterHandler(func(ctx context.Context, s stats.RPCStats) { + stats.RegisterRPCHandler(func(ctx context.Context, s stats.RPCStats) { mu.Lock() defer mu.Unlock() if s.IsClient() { got = append(got, &gotData{ctx, true, s}) } }) + stats.RegisterConnHandler(func(ctx context.Context, s stats.ConnStats) { + mu.Lock() + defer mu.Unlock() + if s.IsClient() { + got = append(got, &gotData{ctx, true, s}) + } + }) + stats.RegisterConnTagger(tagConnCtx) + stats.RegisterRPCTagger(tagRPCCtx) stats.Start() defer stats.Stop() @@ -1016,12 +1247,14 @@ func TestClientStatsStreamingRPCError(t *testing.T) { } checkFuncs := map[int]*checkFuncWithCount{ + connbegin: {checkConnBegin, 1}, begin: {checkBegin, 1}, outHeader: {checkOutHeader, 1}, outPayload: {checkOutPayload, 1}, inHeader: {checkInHeader, 1}, inTrailer: {checkInTrailer, 1}, end: {checkEnd, 1}, + connend: {checkConnEnd, 1}, } checkClientStats(t, got, expect, checkFuncs) diff --git a/stream.go b/stream.go index 95c8acf8d..1bcd2183a 100644 --- a/stream.go +++ b/stream.go @@ -145,12 +145,13 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth }() } if stats.On() { + ctx = stats.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method}) begin := &stats.Begin{ Client: true, BeginTime: time.Now(), FailFast: c.failFast, } - stats.Handle(ctx, begin) + stats.HandleRPC(ctx, begin) } defer func() { if err != nil && stats.On() { @@ -159,7 +160,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth Client: true, Error: err, } - stats.Handle(ctx, end) + stats.HandleRPC(ctx, end) } }() gopts := BalancerGetOptions{ @@ -342,7 +343,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { err = cs.t.Write(cs.s, out, &transport.Options{Last: false}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() - stats.Handle(cs.statsCtx, outPayload) + stats.HandleRPC(cs.statsCtx, outPayload) } return err } @@ -360,7 +361,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { if err != io.EOF { end.Error = toRPCErr(err) } - stats.Handle(cs.statsCtx, end) + stats.HandleRPC(cs.statsCtx, end) } }() var inPayload *stats.InPayload @@ -385,7 +386,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.mu.Unlock() } if inPayload != nil { - stats.Handle(cs.statsCtx, inPayload) + stats.HandleRPC(cs.statsCtx, inPayload) } if !cs.desc.ClientStreams || cs.desc.ServerStreams { return @@ -565,7 +566,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } if outPayload != nil { outPayload.SentTime = time.Now() - stats.Handle(ss.s.Context(), outPayload) + stats.HandleRPC(ss.s.Context(), outPayload) } return nil } @@ -599,7 +600,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { return toRPCErr(err) } if inPayload != nil { - stats.Handle(ss.s.Context(), inPayload) + stats.HandleRPC(ss.s.Context(), inPayload) } return nil } diff --git a/transport/http2_client.go b/transport/http2_client.go index cbd9f3260..5640aea80 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -56,6 +56,7 @@ import ( // http2Client implements the ClientTransport interface with HTTP2. type http2Client struct { + ctx context.Context target string // server name/addr userAgent string md interface{} @@ -181,6 +182,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } var buf bytes.Buffer t := &http2Client{ + ctx: ctx, target: addr.Addr, userAgent: ua, md: addr.Metadata, @@ -242,6 +244,16 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) ( } go t.controller() t.writableChan <- 0 + if stats.On() { + t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{ + Client: true, + } + stats.HandleConn(t.ctx, connBegin) + } return t, nil } @@ -467,7 +479,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea LocalAddr: t.localAddr, Compression: callHdr.SendCompress, } - stats.Handle(s.clientStatsCtx, outHeader) + stats.HandleRPC(s.clientStatsCtx, outHeader) } t.writableChan <- 0 return s, nil @@ -547,6 +559,12 @@ func (t *http2Client) Close() (err error) { s.mu.Unlock() s.write(recvMsg{err: ErrConnClosing}) } + if stats.On() { + connEnd := &stats.ConnEnd{ + Client: true, + } + stats.HandleConn(t.ctx, connEnd) + } return } @@ -904,13 +922,13 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { Client: true, WireLength: int(frame.Header().Length), } - stats.Handle(s.clientStatsCtx, inHeader) + stats.HandleRPC(s.clientStatsCtx, inHeader) } else { inTrailer := &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), } - stats.Handle(s.clientStatsCtx, inTrailer) + stats.HandleRPC(s.clientStatsCtx, inTrailer) } } }() diff --git a/transport/http2_server.go b/transport/http2_server.go index db9beb90a..62ea3037a 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -60,6 +60,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe // http2Server implements the ServerTransport interface with HTTP2. type http2Server struct { + ctx context.Context conn net.Conn remoteAddr net.Addr localAddr net.Addr @@ -127,6 +128,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } var buf bytes.Buffer t := &http2Server{ + ctx: context.Background(), conn: conn, remoteAddr: conn.RemoteAddr(), localAddr: conn.LocalAddr(), @@ -145,6 +147,14 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err activeStreams: make(map[uint32]*Stream), streamSendQuota: defaultWindowSize, } + if stats.On() { + t.ctx = stats.TagConn(t.ctx, &stats.ConnTagInfo{ + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + }) + connBegin := &stats.ConnBegin{} + stats.HandleConn(t.ctx, connBegin) + } go t.controller() t.writableChan <- 0 return t, nil @@ -177,9 +187,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.recvCompress = state.encoding if state.timeoutSet { - s.ctx, s.cancel = context.WithTimeout(context.TODO(), state.timeout) + s.ctx, s.cancel = context.WithTimeout(t.ctx, state.timeout) } else { - s.ctx, s.cancel = context.WithCancel(context.TODO()) + s.ctx, s.cancel = context.WithCancel(t.ctx) } pr := &peer.Peer{ Addr: t.remoteAddr, @@ -241,6 +251,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } s.ctx = traceCtx(s.ctx, s.method) if stats.On() { + s.ctx = stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) inHeader := &stats.InHeader{ FullMethod: s.method, RemoteAddr: t.remoteAddr, @@ -248,7 +259,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( Compression: s.recvCompress, WireLength: int(frame.Header().Length), } - stats.Handle(s.ctx, inHeader) + stats.HandleRPC(s.ctx, inHeader) } handle(s) return @@ -533,7 +544,7 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { outHeader := &stats.OutHeader{ WireLength: bufLen, } - stats.Handle(s.Context(), outHeader) + stats.HandleRPC(s.Context(), outHeader) } t.writableChan <- 0 return nil @@ -596,7 +607,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s outTrailer := &stats.OutTrailer{ WireLength: bufLen, } - stats.Handle(s.Context(), outTrailer) + stats.HandleRPC(s.Context(), outTrailer) } t.closeStream(s) t.writableChan <- 0 @@ -783,6 +794,10 @@ func (t *http2Server) Close() (err error) { for _, s := range streams { s.cancel() } + if stats.On() { + connEnd := &stats.ConnEnd{} + stats.HandleConn(t.ctx, connEnd) + } return }