From 6445dedfbc8042d37d1beb4ef5a75ae0b6967517 Mon Sep 17 00:00:00 2001 From: Menghan Li Date: Tue, 8 Nov 2016 18:17:14 -0800 Subject: [PATCH] fix wrong context when handling stats --- call_test.go | 2 ++ server.go | 15 +++++++++++++-- stats/stats_test.go | 30 ++++++++++++++++++++++++++++++ transport/handler_server.go | 2 +- transport/http2_server.go | 8 +++++--- transport/transport.go | 2 +- 6 files changed, 52 insertions(+), 7 deletions(-) diff --git a/call_test.go b/call_test.go index 4c9d1f8a5..3c2165eac 100644 --- a/call_test.go +++ b/call_test.go @@ -185,6 +185,8 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) { } go st.HandleStreams(func(s *transport.Stream) { go h.handleStream(t, s) + }, func(ctx context.Context, method string) context.Context { + return ctx }) } } diff --git a/server.go b/server.go index f0ed00f39..762ac33fe 100644 --- a/server.go +++ b/server.go @@ -467,6 +467,12 @@ func (s *Server) serveStreams(st transport.ServerTransport) { defer wg.Done() s.handleStream(st, stream, s.traceInfo(st, stream)) }() + }, func(ctx context.Context, method string) context.Context { + if !EnableTracing { + return ctx + } + tr := trace.New("grpc.Recv."+methodFamily(method), method) + return trace.NewContext(ctx, tr) }) wg.Wait() } @@ -519,12 +525,17 @@ func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Strea if !EnableTracing { return nil } + tr, ok := trace.FromContext(stream.Context()) + if !ok { + grpclog.Fatalf("cannot get trace from context while EnableTracing == true") + } + trInfo = &traceInfo{ - tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), + tr: tr, } trInfo.firstLine.client = false trInfo.firstLine.remoteAddr = st.RemoteAddr() - stream.TraceContext(trInfo.tr) + if dl, ok := stream.Context().Deadline(); ok { trInfo.firstLine.deadline = dl.Sub(time.Now()) } diff --git a/stats/stats_test.go b/stats/stats_test.go index 8bd54d027..97da256b5 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -592,6 +592,12 @@ func TestServerStatsUnaryRPC(t *testing.T) { t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + for i, f := range checkFuncs { mu.Lock() f(t, got[i], expect) @@ -645,6 +651,12 @@ func TestServerStatsUnaryRPCError(t *testing.T) { t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + for i, f := range checkFuncs { mu.Lock() f(t, got[i], expect) @@ -704,6 +716,12 @@ func TestServerStatsStreamingRPC(t *testing.T) { t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + for i, f := range checkFuncs { mu.Lock() f(t, got[i], expect) @@ -759,6 +777,12 @@ func TestServerStatsStreamingRPCError(t *testing.T) { t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + for i, f := range checkFuncs { mu.Lock() f(t, got[i], expect) @@ -923,6 +947,12 @@ func TestClientStatsUnaryRPCError(t *testing.T) { t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) } + for i := 0; i < len(got)-1; i++ { + if got[i].ctx != got[i+1].ctx { + t.Fatalf("got different contexts with two stats %T %T", got[i].s, got[i+1].s) + } + } + for i, f := range checkFuncs { mu.Lock() f(t, got[i], expect) diff --git a/transport/handler_server.go b/transport/handler_server.go index 114e34906..10b6dc0b1 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -268,7 +268,7 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error { }) } -func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { +func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), traceCtx func(context.Context, string) context.Context) { // With this transport type there will be exactly 1 stream: this HTTP request. var ctx context.Context diff --git a/transport/http2_server.go b/transport/http2_server.go index c1ac6af6d..8a33e9273 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -151,7 +151,7 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err } // operateHeader takes action on the decoded headers. -func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) (close bool) { +func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (close bool) { buf := newRecvBuffer() s := &Stream{ id: frame.Header().StreamID, @@ -239,6 +239,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.windowHandler = func(n int) { t.updateWindow(s, uint32(n)) } + s.ctx = traceCtx(s.ctx, s.method) if stats.On() { inHeader := &stats.InHeader{ FullMethod: s.method, @@ -255,7 +256,8 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( // HandleStreams receives incoming streams using the given handler. This is // typically run in a separate goroutine. -func (t *http2Server) HandleStreams(handle func(*Stream)) { +// traceCtx attaches trace to ctx and returns the new context. +func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) { // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { @@ -310,7 +312,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) { } switch frame := frame.(type) { case *http2.MetaHeadersFrame: - if t.operateHeaders(frame, handle) { + if t.operateHeaders(frame, handle, traceCtx) { t.Close() break } diff --git a/transport/transport.go b/transport/transport.go index 82a2d7f2e..f782a99b3 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -479,7 +479,7 @@ type ClientTransport interface { // Write methods for a given Stream will be called serially. type ServerTransport interface { // HandleStreams receives incoming streams using the given handler. - HandleStreams(func(*Stream)) + HandleStreams(func(*Stream), func(context.Context, string) context.Context) // WriteHeader sends the header metadata for the given stream. // WriteHeader may not be called on all streams.