From 9d2ecf553a16eecc7047e20262918a672f8dc4b9 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 28 Jan 2016 19:52:42 +0000 Subject: [PATCH] server: break up the Server.Serve method into some reusable parts Updates grpc/grpc-go#75 --- server.go | 107 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 66 insertions(+), 41 deletions(-) diff --git a/server.go b/server.go index dd8642751..f6ee266c4 100644 --- a/server.go +++ b/server.go @@ -264,49 +264,74 @@ func (s *Server) Serve(lis net.Listener) error { } s.mu.Unlock() + go s.serveNewHTTP2Transport(c, authInfo) + } +} + +func (s *Server) serveNewHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) { + st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) + if err != nil { + s.mu.Lock() + s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) + s.mu.Unlock() + c.Close() + grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err) + return + } + if !s.addConn(st) { + c.Close() + return + } + s.serveStreams(st) +} + +func (s *Server) serveStreams(st transport.ServerTransport) { + defer s.removeConn(st) + defer st.Close() + var wg sync.WaitGroup + st.HandleStreams(func(stream *transport.Stream) { + wg.Add(1) go func() { - st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) - if err != nil { - s.mu.Lock() - s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) - s.mu.Unlock() - c.Close() - grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err) - return - } - defer st.Close() - s.mu.Lock() - if s.conns == nil { - s.mu.Unlock() - return - } - s.conns[st] = true - s.mu.Unlock() - var wg sync.WaitGroup - st.HandleStreams(func(stream *transport.Stream) { - var trInfo *traceInfo - if EnableTracing { - trInfo = &traceInfo{ - tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), - } - trInfo.firstLine.client = false - trInfo.firstLine.remoteAddr = st.RemoteAddr() - stream.TraceContext(trInfo.tr) - if dl, ok := stream.Context().Deadline(); ok { - trInfo.firstLine.deadline = dl.Sub(time.Now()) - } - } - wg.Add(1) - go func() { - s.handleStream(st, stream, trInfo) - wg.Done() - }() - }) - wg.Wait() - s.mu.Lock() - delete(s.conns, st) - s.mu.Unlock() + defer wg.Done() + s.handleStream(st, stream, s.traceInfo(st, stream)) }() + }) + wg.Wait() +} + +// traceInfo returns a traceInfo and associates it with stream, if tracing is enabled. +// If tracing is not enabled, it returns nil. +func (s *Server) traceInfo(st transport.ServerTransport, stream *transport.Stream) (trInfo *traceInfo) { + if !EnableTracing { + return nil + } + trInfo = &traceInfo{ + tr: trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()), + } + trInfo.firstLine.client = false + trInfo.firstLine.remoteAddr = st.RemoteAddr() + stream.TraceContext(trInfo.tr) + if dl, ok := stream.Context().Deadline(); ok { + trInfo.firstLine.deadline = dl.Sub(time.Now()) + } + return trInfo +} + +func (s *Server) addConn(st transport.ServerTransport) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.conns == nil { + return false + } + s.conns[st] = true + return true +} + +func (s *Server) removeConn(st transport.ServerTransport) { + s.mu.Lock() + defer s.mu.Unlock() + if s.conns != nil { + delete(s.conns, st) } }