grpc: Move some stats handler calls to gRPC layer, and add local address to peer.Peer (#6716)

This commit is contained in:
Zach Reyes 2023-10-25 18:01:05 -04:00 committed by GitHub
parent 6e14274d00
commit c76d75f4f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 136 additions and 162 deletions

View File

@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
return nil, errors.New(msg) return nil, errors.New(msg)
} }
var localAddr net.Addr
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr, _ = la.(net.Addr)
}
var authInfo credentials.AuthInfo
if r.TLS != nil {
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
p := peer.Peer{
Addr: strAddr(r.RemoteAddr),
LocalAddr: localAddr,
AuthInfo: authInfo,
}
st := &serverHandlerTransport{ st := &serverHandlerTransport{
rw: w, rw: w,
req: r, req: r,
closedCh: make(chan struct{}), closedCh: make(chan struct{}),
writes: make(chan func()), writes: make(chan func()),
peer: p,
contentType: contentType, contentType: contentType,
contentSubtype: contentSubtype, contentSubtype: contentSubtype,
stats: stats, stats: stats,
@ -134,6 +148,8 @@ type serverHandlerTransport struct {
headerMD metadata.MD headerMD metadata.MD
peer peer.Peer
closeOnce sync.Once closeOnce sync.Once
closedCh chan struct{} // closed on Close closedCh chan struct{} // closed on Close
@ -165,7 +181,13 @@ func (ht *serverHandlerTransport) Close(err error) {
}) })
} }
func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) } func (ht *serverHandlerTransport) Peer() *peer.Peer {
return &peer.Peer{
Addr: ht.peer.Addr,
LocalAddr: ht.peer.LocalAddr,
AuthInfo: ht.peer.AuthInfo,
}
}
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or // strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown. // the empty string if unknown.
@ -347,10 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err return err
} }
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) { func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request. // With this transport type there will be exactly 1 stream: this HTTP request.
ctx := ht.req.Context()
var cancel context.CancelFunc var cancel context.CancelFunc
if ht.timeoutSet { if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout) ctx, cancel = context.WithTimeout(ctx, ht.timeout)
@ -370,34 +390,19 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
ht.Close(errors.New("request is done processing")) ht.Close(errors.New("request is done processing"))
}() }()
req := ht.req
s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr) req := ht.req
for _, sh := range ht.stats { s := &Stream{
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method}) id: 0, // irrelevant
inHeader := &stats.InHeader{ ctx: ctx,
FullMethod: s.method, requestRead: func(int) {},
RemoteAddr: ht.RemoteAddr(), cancel: cancel,
Compression: s.recvCompress, buf: newRecvBuffer(),
} st: ht,
sh.HandleRPC(s.ctx, inHeader) method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
} }
s.trReader = &transportReader{ s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}}, reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},

View File

@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st.ht.WriteStatus(s, status.New(codes.OK, "")) st.ht.WriteStatus(s, status.New(codes.OK, ""))
} }
st.ht.HandleStreams( st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) }, context.Background(), func(s *Stream) { go handleStream(s) },
) )
wantHeader := http.Header{ wantHeader := http.Header{
"Date": nil, "Date": nil,
@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
st.ht.WriteStatus(s, status.New(statusCode, msg)) st.ht.WriteStatus(s, status.New(statusCode, msg))
} }
st.ht.HandleStreams( st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) }, context.Background(), func(s *Stream) { go handleStream(s) },
) )
wantHeader := http.Header{ wantHeader := http.Header{
"Date": nil, "Date": nil,
@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow")) ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
} }
ht.HandleStreams( ht.HandleStreams(
func(s *Stream) { go runStream(s) }, context.Background(), func(s *Stream) { go runStream(s) },
) )
wantHeader := http.Header{ wantHeader := http.Header{
"Date": nil, "Date": nil,
@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) { func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
st := newHandleStreamTest(t) st := newHandleStreamTest(t)
st.ht.HandleStreams( st.ht.HandleStreams(
func(s *Stream) { go handleStream(st, s) }, context.Background(), func(s *Stream) { go handleStream(st, s) },
) )
} }
@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
hst.ht.WriteStatus(s, st) hst.ht.WriteStatus(s, st)
} }
hst.ht.HandleStreams( hst.ht.HandleStreams(
func(s *Stream) { go handleStream(s) }, context.Background(), func(s *Stream) { go handleStream(s) },
) )
wantHeader := http.Header{ wantHeader := http.Header{
"Date": nil, "Date": nil,

View File

@ -69,15 +69,12 @@ var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2. // http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct { type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically. lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
done chan struct{} done chan struct{}
conn net.Conn conn net.Conn
loopy *loopyWriter loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing. readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing. writerDone chan struct{} // sync point to enable testing.
remoteAddr net.Addr peer peer.Peer
localAddr net.Addr
authInfo credentials.AuthInfo // auth info about the connection
inTapHandle tap.ServerInHandle inTapHandle tap.ServerInHandle
framer *framer framer *framer
// The max number of concurrent streams. // The max number of concurrent streams.
@ -243,13 +240,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
} }
done := make(chan struct{}) done := make(chan struct{})
peer := peer.Peer{
Addr: conn.RemoteAddr(),
LocalAddr: conn.LocalAddr(),
AuthInfo: authInfo,
}
t := &http2Server{ t := &http2Server{
ctx: setConnection(context.Background(), rawConn),
done: done, done: done,
conn: conn, conn: conn,
remoteAddr: conn.RemoteAddr(), peer: peer,
localAddr: conn.LocalAddr(),
authInfo: authInfo,
framer: framer, framer: framer,
readerDone: make(chan struct{}), readerDone: make(chan struct{}),
writerDone: make(chan struct{}), writerDone: make(chan struct{}),
@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
bufferPool: newBufferPool(), bufferPool: newBufferPool(),
} }
t.logger = prefixLoggerForServerTransport(t) t.logger = prefixLoggerForServerTransport(t)
// Add peer information to the http2server context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())
t.controlBuf = newControlBuffer(t.done) t.controlBuf = newControlBuffer(t.done)
if dynamicWindow { if dynamicWindow {
@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl, updateFlowControl: t.updateFlowControl,
} }
} }
for _, sh := range t.stats { t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -342,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
// operateHeaders takes action on the decoded headers. Returns an error if fatal // operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil. // error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error { func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration // Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock() t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock() defer t.maxStreamMu.Unlock()
@ -369,10 +358,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
buf := newRecvBuffer() buf := newRecvBuffer()
s := &Stream{ s := &Stream{
id: streamID, id: streamID,
st: t, st: t,
buf: buf, buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)}, fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
} }
var ( var (
// if false, content-type was missing or invalid // if false, content-type was missing or invalid
@ -511,9 +501,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.state = streamReadDone s.state = streamReadDone
} }
if timeoutSet { if timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout) s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
} else { } else {
s.ctx, s.cancel = context.WithCancel(t.ctx) s.ctx, s.cancel = context.WithCancel(ctx)
} }
// Attach the received metadata to the context. // Attach the received metadata to the context.
@ -592,18 +582,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) { s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n)) t.adjustWindow(s, uint32(n))
} }
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: mdata.Copy(),
}
sh.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done() s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{ s.trReader = &transportReader{
@ -629,7 +607,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is // HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine. // typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context. // traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream)) { func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
defer close(t.readerDone) defer close(t.readerDone)
for { for {
t.controlBuf.throttle() t.controlBuf.throttle()
@ -664,7 +642,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
} }
switch frame := frame.(type) { switch frame := frame.(type) {
case *http2.MetaHeadersFrame: case *http2.MetaHeadersFrame:
if err := t.operateHeaders(frame, handle); err != nil { if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err) t.Close(err)
break break
} }
@ -1242,10 +1220,6 @@ func (t *http2Server) Close(err error) {
for _, s := range streams { for _, s := range streams {
s.cancel() s.cancel()
} }
for _, sh := range t.stats {
connEnd := &stats.ConnEnd{}
sh.HandleConn(t.ctx, connEnd)
}
} }
// deleteStream deletes the stream s from transport's active streams. // deleteStream deletes the stream s from transport's active streams.
@ -1311,10 +1285,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
}) })
} }
func (t *http2Server) RemoteAddr() net.Addr {
return t.remoteAddr
}
func (t *http2Server) Drain(debugData string) { func (t *http2Server) Drain(debugData string) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -1397,11 +1367,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)), LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()), LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn), SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr, LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.remoteAddr, RemoteAddr: t.peer.Addr,
// RemoteName : // RemoteName :
} }
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok { if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue() s.Security = au.GetSecurityValue()
} }
s.RemoteFlowControlWindow = t.getOutFlowWindow() s.RemoteFlowControlWindow = t.getOutFlowWindow()
@ -1433,10 +1403,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
} }
} }
func (t *http2Server) getPeer() *peer.Peer { // Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{ return &peer.Peer{
Addr: t.remoteAddr, Addr: t.peer.Addr,
AuthInfo: t.authInfo, // Can be nil LocalAddr: t.peer.LocalAddr,
AuthInfo: t.peer.AuthInfo, // Can be nil
} }
} }
@ -1461,6 +1433,6 @@ func GetConnection(ctx context.Context) net.Conn {
// SetConnection adds the connection to the context to be able to get // SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also // information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection. // allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context { func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn) return context.WithValue(ctx, connectionKey{}, conn)
} }

View File

@ -37,6 +37,7 @@ import (
"google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -265,7 +266,8 @@ type Stream struct {
// headerValid indicates whether a valid header was received. Only // headerValid indicates whether a valid header was received. Only
// meaningful after headerChan is closed (always call waitOnHeader() before // meaningful after headerChan is closed (always call waitOnHeader() before
// reading its value). Not valid on server side. // reading its value). Not valid on server side.
headerValid bool headerValid bool
headerWireLength int // Only set on server side.
// hdrMu protects header and trailer metadata on the server-side. // hdrMu protects header and trailer metadata on the server-side.
hdrMu sync.Mutex hdrMu sync.Mutex
@ -425,6 +427,12 @@ func (s *Stream) Context() context.Context {
return s.ctx return s.ctx
} }
// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *Stream) SetContext(ctx context.Context) {
s.ctx = ctx
}
// Method returns the method for the stream. // Method returns the method for the stream.
func (s *Stream) Method() string { func (s *Stream) Method() string {
return s.method return s.method
@ -437,6 +445,12 @@ func (s *Stream) Status() *status.Status {
return s.status return s.status
} }
// HeaderWireLength returns the size of the headers of the stream as received
// from the wire. Valid only on the server.
func (s *Stream) HeaderWireLength() int {
return s.headerWireLength
}
// SetHeader sets the header metadata. This can be called multiple times. // SetHeader sets the header metadata. This can be called multiple times.
// Server side only. // Server side only.
// This should not be called in parallel to other data writes. // This should not be called in parallel to other data writes.
@ -698,7 +712,7 @@ type ClientTransport interface {
// Write methods for a given Stream will be called serially. // Write methods for a given Stream will be called serially.
type ServerTransport interface { type ServerTransport interface {
// HandleStreams receives incoming streams using the given handler. // HandleStreams receives incoming streams using the given handler.
HandleStreams(func(*Stream)) HandleStreams(context.Context, func(*Stream))
// WriteHeader sends the header metadata for the given stream. // WriteHeader sends the header metadata for the given stream.
// WriteHeader may not be called on all streams. // WriteHeader may not be called on all streams.
@ -717,8 +731,8 @@ type ServerTransport interface {
// handlers will be terminated asynchronously. // handlers will be terminated asynchronously.
Close(err error) Close(err error)
// RemoteAddr returns the remote network address. // Peer returns the peer of the server transport.
RemoteAddr() net.Addr Peer() *peer.Peer
// Drain notifies the client this ServerTransport stops accepting new RPCs. // Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain(debugData string) Drain(debugData string)

View File

@ -35,8 +35,6 @@ import (
"testing" "testing"
"time" "time"
"google.golang.org/grpc/peer"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
@ -356,19 +354,19 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Unlock() s.mu.Unlock()
switch ht { switch ht {
case notifyCall: case notifyCall:
go transport.HandleStreams(h.handleStreamAndNotify) go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
case suspended: case suspended:
go transport.HandleStreams(func(*Stream) {}) go transport.HandleStreams(context.Background(), func(*Stream) {})
case misbehaved: case misbehaved:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStreamMisbehave(t, s) go h.handleStreamMisbehave(t, s)
}) })
case encodingRequiredStatus: case encodingRequiredStatus:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStreamEncodingRequiredStatus(s) go h.handleStreamEncodingRequiredStatus(s)
}) })
case invalidHeaderField: case invalidHeaderField:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStreamInvalidHeaderField(s) go h.handleStreamInvalidHeaderField(s)
}) })
case delayRead: case delayRead:
@ -377,15 +375,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
s.mu.Lock() s.mu.Lock()
close(s.ready) close(s.ready)
s.mu.Unlock() s.mu.Unlock()
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStreamDelayRead(t, s) go h.handleStreamDelayRead(t, s)
}) })
case pingpong: case pingpong:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStreamPingPong(t, s) go h.handleStreamPingPong(t, s)
}) })
default: default:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(context.Background(), func(s *Stream) {
go h.handleStream(t, s) go h.handleStream(t, s)
}) })
} }
@ -2594,52 +2592,3 @@ func TestConnectionError_Unwrap(t *testing.T) {
t.Error("ConnectionError does not unwrap") t.Error("ConnectionError does not unwrap")
} }
} }
func (s) TestPeerSetInServerContext(t *testing.T) {
// create client and server transports.
server, client, cancel := setUp(t, 0, normal)
defer cancel()
defer server.stop()
defer client.Close(fmt.Errorf("closed manually by test"))
// create a stream with client transport.
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
stream, err := client.NewStream(ctx, &CallHdr{})
if err != nil {
t.Fatalf("failed to create a stream: %v", err)
}
waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()
if len(server.conns) == 0 {
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
}
return false, nil
})
// verify peer is set in client transport context.
if _, ok := peer.FromContext(client.ctx); !ok {
t.Fatalf("Peer expected in client transport's context, but actually not found.")
}
// verify peer is set in stream context.
if _, ok := peer.FromContext(stream.ctx); !ok {
t.Fatalf("Peer expected in stream context, but actually not found.")
}
// verify peer is set in server transport context.
server.mu.Lock()
for k := range server.conns {
sc, ok := k.(*http2Server)
if !ok {
t.Fatalf("ServerTransport is of type %T, want %T", k, &http2Server{})
}
if _, ok = peer.FromContext(sc.ctx); !ok {
t.Fatalf("Peer expected in server transport's context, but actually not found.")
}
}
server.mu.Unlock()
}

View File

@ -28,7 +28,6 @@ import (
"net" "net"
"strconv" "strconv"
v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/authz/audit" "google.golang.org/grpc/authz/audit"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@ -38,6 +37,8 @@ import (
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
) )
var logger = grpclog.Component("rbac") var logger = grpclog.Component("rbac")

View File

@ -32,6 +32,8 @@ import (
type Peer struct { type Peer struct {
// Addr is the peer address. // Addr is the peer address.
Addr net.Addr Addr net.Addr
// LocalAddr is the local address.
LocalAddr net.Addr
// AuthInfo is the authentication information of the transport. // AuthInfo is the authentication information of the transport.
// It is nil if there is no transport security being used. // It is nil if there is no transport security being used.
AuthInfo credentials.AuthInfo AuthInfo credentials.AuthInfo

View File

@ -920,7 +920,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return return
} }
go func() { go func() {
s.serveStreams(st) s.serveStreams(context.Background(), st, rawConn)
s.removeConn(lisAddr, st) s.removeConn(lisAddr, st)
}() }()
} }
@ -974,12 +974,27 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
return st return st
} }
func (s *Server) serveStreams(st transport.ServerTransport) { func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
defer st.Close(errors.New("finished serving streams for the server transport")) ctx = transport.SetConnection(ctx, rawConn)
var wg sync.WaitGroup ctx = peer.NewContext(ctx, st.Peer())
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
RemoteAddr: st.Peer().Addr,
LocalAddr: st.Peer().LocalAddr,
})
sh.HandleConn(ctx, &stats.ConnBegin{})
}
defer func() {
st.Close(errors.New("finished serving streams for the server transport"))
for _, sh := range s.opts.statsHandlers {
sh.HandleConn(ctx, &stats.ConnEnd{})
}
}()
var wg sync.WaitGroup
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams) streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
st.HandleStreams(func(stream *transport.Stream) { st.HandleStreams(ctx, func(stream *transport.Stream) {
wg.Add(1) wg.Add(1)
streamQuota.acquire() streamQuota.acquire()
@ -1043,7 +1058,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
defer s.removeConn(listenerAddressForServeHTTP, st) defer s.removeConn(listenerAddressForServeHTTP, st)
s.serveStreams(st) s.serveStreams(r.Context(), st, nil)
} }
func (s *Server) addConn(addr string, st transport.ServerTransport) bool { func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
@ -1700,7 +1715,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
tr: tr, tr: tr,
firstLine: firstLine{ firstLine: firstLine{
client: false, client: false,
remoteAddr: t.RemoteAddr(), remoteAddr: t.Peer().Addr,
}, },
} }
if dl, ok := ctx.Deadline(); ok { if dl, ok := ctx.Deadline(); ok {
@ -1734,6 +1749,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
service := sm[:pos] service := sm[:pos]
method := sm[pos+1:] method := sm[pos+1:]
md, _ := metadata.FromIncomingContext(ctx)
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
sh.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(),
RemoteAddr: t.Peer().Addr,
LocalAddr: t.Peer().LocalAddr,
Compression: stream.RecvCompress(),
WireLength: stream.HeaderWireLength(),
Header: md,
})
}
// To have calls in stream callouts work. Will delete once all stats handler
// calls come from the gRPC layer.
stream.SetContext(ctx)
srv, knownService := s.services[service] srv, knownService := s.services[service]
if knownService { if knownService {
if md, ok := srv.methods[method]; ok { if md, ok := srv.methods[method]; ok {