mirror of https://github.com/grpc/grpc-go.git
grpc: Move some stats handler calls to gRPC layer, and add local address to peer.Peer (#6716)
This commit is contained in:
parent 6e14274d00
commit c76d75f4f9
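Taken together, the change means a server's stats handlers are now invoked from the gRPC layer (server.go) rather than from inside each transport, and peer.Peer exposes the connection's local address next to the remote one. Below is a minimal sketch of a stats handler observing both addresses; the addrLogger type and the log output are illustrative only and not part of this commit.

package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/stats"
)

// addrLogger is a hypothetical stats.Handler used to show which fields the
// gRPC layer now populates.
type addrLogger struct{}

func (addrLogger) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
	// Both addresses come from ServerTransport.Peer() after this change.
	log.Printf("conn: remote=%v local=%v", info.RemoteAddr, info.LocalAddr)
	return ctx
}

func (addrLogger) HandleConn(ctx context.Context, s stats.ConnStats) {}

func (addrLogger) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
	return ctx
}

func (addrLogger) HandleRPC(ctx context.Context, s stats.RPCStats) {
	if h, ok := s.(*stats.InHeader); ok {
		// InHeader now carries LocalAddr for both server transports.
		log.Printf("rpc %s: remote=%v local=%v", h.FullMethod, h.RemoteAddr, h.LocalAddr)
	}
}

func main() {
	// Register the handler when constructing the server.
	_ = grpc.NewServer(grpc.StatsHandler(addrLogger{}))
}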
@@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
 		return nil, errors.New(msg)
 	}

+	var localAddr net.Addr
+	if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
+		localAddr, _ = la.(net.Addr)
+	}
+	var authInfo credentials.AuthInfo
+	if r.TLS != nil {
+		authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
+	}
+	p := peer.Peer{
+		Addr:      strAddr(r.RemoteAddr),
+		LocalAddr: localAddr,
+		AuthInfo:  authInfo,
+	}
 	st := &serverHandlerTransport{
 		rw:             w,
 		req:            r,
 		closedCh:       make(chan struct{}),
 		writes:         make(chan func()),
+		peer:           p,
 		contentType:    contentType,
 		contentSubtype: contentSubtype,
 		stats:          stats,

@@ -134,6 +148,8 @@ type serverHandlerTransport struct {

 	headerMD metadata.MD

+	peer peer.Peer
+
 	closeOnce sync.Once
 	closedCh  chan struct{} // closed on Close

@@ -165,7 +181,13 @@ func (ht *serverHandlerTransport) Close(err error) {
 	})
 }

-func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
+func (ht *serverHandlerTransport) Peer() *peer.Peer {
+	return &peer.Peer{
+		Addr:      ht.peer.Addr,
+		LocalAddr: ht.peer.LocalAddr,
+		AuthInfo:  ht.peer.AuthInfo,
+	}
+}

 // strAddr is a net.Addr backed by either a TCP "ip:port" string, or
 // the empty string if unknown.

@@ -347,10 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
 	return err
 }

-func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
+func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
 	// With this transport type there will be exactly 1 stream: this HTTP request.
-
-	ctx := ht.req.Context()
 	var cancel context.CancelFunc
 	if ht.timeoutSet {
 		ctx, cancel = context.WithTimeout(ctx, ht.timeout)
@@ -370,34 +390,19 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
 		ht.Close(errors.New("request is done processing"))
 	}()

-	req := ht.req
-
-	s := &Stream{
-		id:             0, // irrelevant
-		requestRead:    func(int) {},
-		cancel:         cancel,
-		buf:            newRecvBuffer(),
-		st:             ht,
-		method:         req.URL.Path,
-		recvCompress:   req.Header.Get("grpc-encoding"),
-		contentSubtype: ht.contentSubtype,
-	}
-	pr := &peer.Peer{
-		Addr: ht.RemoteAddr(),
-	}
-	if req.TLS != nil {
-		pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
-	}
-	ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
-	s.ctx = peer.NewContext(ctx, pr)
-	for _, sh := range ht.stats {
-		s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
-		inHeader := &stats.InHeader{
-			FullMethod:  s.method,
-			RemoteAddr:  ht.RemoteAddr(),
-			Compression: s.recvCompress,
-		}
-		sh.HandleRPC(s.ctx, inHeader)
-	}
+	req := ht.req
+	s := &Stream{
+		id:               0, // irrelevant
+		ctx:              ctx,
+		requestRead:      func(int) {},
+		cancel:           cancel,
+		buf:              newRecvBuffer(),
+		st:               ht,
+		method:           req.URL.Path,
+		recvCompress:     req.Header.Get("grpc-encoding"),
+		contentSubtype:   ht.contentSubtype,
+		headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
+	}
 	s.trReader = &transportReader{
 		reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
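In NewServerHandlerTransport above, the local address comes from the standard library: net/http stores each connection's local address on the request context under http.LocalAddrContextKey. The following is a self-contained sketch of that mechanism outside gRPC; the address and route are placeholders.

package main

import (
	"fmt"
	"net"
	"net/http"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// The http.Server sets LocalAddrContextKey on the base context of
		// every accepted connection before invoking handlers.
		la, _ := r.Context().Value(http.LocalAddrContextKey).(net.Addr)
		fmt.Fprintf(w, "remote=%s local=%v\n", r.RemoteAddr, la)
	})
	_ = http.ListenAndServe("localhost:8080", nil)
}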
@@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
 		st.ht.WriteStatus(s, status.New(codes.OK, ""))
 	}
 	st.ht.HandleStreams(
-		func(s *Stream) { go handleStream(s) },
+		context.Background(), func(s *Stream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,

@@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
 		st.ht.WriteStatus(s, status.New(statusCode, msg))
 	}
 	st.ht.HandleStreams(
-		func(s *Stream) { go handleStream(s) },
+		context.Background(), func(s *Stream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,

@@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
 		ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
 	}
 	ht.HandleStreams(
-		func(s *Stream) { go runStream(s) },
+		context.Background(), func(s *Stream) { go runStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,

@@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
 func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
 	st := newHandleStreamTest(t)
 	st.ht.HandleStreams(
-		func(s *Stream) { go handleStream(st, s) },
+		context.Background(), func(s *Stream) { go handleStream(st, s) },
 	)
 }

@@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
 		hst.ht.WriteStatus(s, st)
 	}
 	hst.ht.HandleStreams(
-		func(s *Stream) { go handleStream(s) },
+		context.Background(), func(s *Stream) { go handleStream(s) },
 	)
 	wantHeader := http.Header{
 		"Date": nil,
@@ -69,15 +69,12 @@ var serverConnectionCounter uint64
 // http2Server implements the ServerTransport interface with HTTP2.
 type http2Server struct {
 	lastRead    int64 // Keep this field 64-bit aligned. Accessed atomically.
-	ctx         context.Context
 	done        chan struct{}
 	conn        net.Conn
 	loopy       *loopyWriter
 	readerDone  chan struct{} // sync point to enable testing.
 	writerDone  chan struct{} // sync point to enable testing.
-	remoteAddr  net.Addr
-	localAddr   net.Addr
-	authInfo    credentials.AuthInfo // auth info about the connection
+	peer        peer.Peer
 	inTapHandle tap.ServerInHandle
 	framer      *framer
 	// The max number of concurrent streams.

@@ -243,13 +240,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 	}

 	done := make(chan struct{})
+	peer := peer.Peer{
+		Addr:      conn.RemoteAddr(),
+		LocalAddr: conn.LocalAddr(),
+		AuthInfo:  authInfo,
+	}
 	t := &http2Server{
-		ctx:               setConnection(context.Background(), rawConn),
 		done:              done,
 		conn:              conn,
-		remoteAddr:        conn.RemoteAddr(),
-		localAddr:         conn.LocalAddr(),
-		authInfo:          authInfo,
+		peer:              peer,
 		framer:            framer,
 		readerDone:        make(chan struct{}),
 		writerDone:        make(chan struct{}),

@@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		bufferPool:        newBufferPool(),
 	}
 	t.logger = prefixLoggerForServerTransport(t)
-	// Add peer information to the http2server context.
-	t.ctx = peer.NewContext(t.ctx, t.getPeer())

 	t.controlBuf = newControlBuffer(t.done)
 	if dynamicWindow {

@@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 			updateFlowControl: t.updateFlowControl,
 		}
 	}
-	for _, sh := range t.stats {
-		t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
-			RemoteAddr: t.remoteAddr,
-			LocalAddr:  t.localAddr,
-		})
-		connBegin := &stats.ConnBegin{}
-		sh.HandleConn(t.ctx, connBegin)
-	}
-	t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
+	t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
 	if err != nil {
 		return nil, err
 	}
@@ -342,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

 // operateHeaders takes action on the decoded headers. Returns an error if fatal
 // error encountered and transport needs to close, otherwise returns nil.
-func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
+func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
 	// Acquire max stream ID lock for entire duration
 	t.maxStreamMu.Lock()
 	defer t.maxStreamMu.Unlock()

@@ -369,10 +358,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(

 	buf := newRecvBuffer()
 	s := &Stream{
-		id:  streamID,
-		st:  t,
-		buf: buf,
-		fc:  &inFlow{limit: uint32(t.initialWindowSize)},
+		id:               streamID,
+		st:               t,
+		buf:              buf,
+		fc:               &inFlow{limit: uint32(t.initialWindowSize)},
+		headerWireLength: int(frame.Header().Length),
 	}
 	var (
 		// if false, content-type was missing or invalid

@@ -511,9 +501,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 		s.state = streamReadDone
 	}
 	if timeoutSet {
-		s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout)
+		s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
 	} else {
-		s.ctx, s.cancel = context.WithCancel(t.ctx)
+		s.ctx, s.cancel = context.WithCancel(ctx)
 	}

 	// Attach the received metadata to the context.

@@ -592,18 +582,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 	s.requestRead = func(n int) {
 		t.adjustWindow(s, uint32(n))
 	}
-	for _, sh := range t.stats {
-		s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
-		inHeader := &stats.InHeader{
-			FullMethod:  s.method,
-			RemoteAddr:  t.remoteAddr,
-			LocalAddr:   t.localAddr,
-			Compression: s.recvCompress,
-			WireLength:  int(frame.Header().Length),
-			Header:      mdata.Copy(),
-		}
-		sh.HandleRPC(s.ctx, inHeader)
-	}
 	s.ctxDone = s.ctx.Done()
 	s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
 	s.trReader = &transportReader{

@@ -629,7 +607,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
 // HandleStreams receives incoming streams using the given handler. This is
 // typically run in a separate goroutine.
 // traceCtx attaches trace to ctx and returns the new context.
-func (t *http2Server) HandleStreams(handle func(*Stream)) {
+func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
 	defer close(t.readerDone)
 	for {
 		t.controlBuf.throttle()

@@ -664,7 +642,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 		}
 		switch frame := frame.(type) {
 		case *http2.MetaHeadersFrame:
-			if err := t.operateHeaders(frame, handle); err != nil {
+			if err := t.operateHeaders(ctx, frame, handle); err != nil {
 				t.Close(err)
 				break
 			}
@@ -1242,10 +1220,6 @@ func (t *http2Server) Close(err error) {
 	for _, s := range streams {
 		s.cancel()
 	}
-	for _, sh := range t.stats {
-		connEnd := &stats.ConnEnd{}
-		sh.HandleConn(t.ctx, connEnd)
-	}
 }

 // deleteStream deletes the stream s from transport's active streams.

@@ -1311,10 +1285,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
 	})
 }

-func (t *http2Server) RemoteAddr() net.Addr {
-	return t.remoteAddr
-}
-
 func (t *http2Server) Drain(debugData string) {
 	t.mu.Lock()
 	defer t.mu.Unlock()

@@ -1397,11 +1367,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
 		LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
 		LocalFlowControlWindow:       int64(t.fc.getSize()),
 		SocketOptions:                channelz.GetSocketOption(t.conn),
-		LocalAddr:                    t.localAddr,
-		RemoteAddr:                   t.remoteAddr,
+		LocalAddr:                    t.peer.LocalAddr,
+		RemoteAddr:                   t.peer.Addr,
 		// RemoteName :
 	}
-	if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
+	if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
 		s.Security = au.GetSecurityValue()
 	}
 	s.RemoteFlowControlWindow = t.getOutFlowWindow()

@@ -1433,10 +1403,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
 	}
 }

-func (t *http2Server) getPeer() *peer.Peer {
+// Peer returns the peer of the transport.
+func (t *http2Server) Peer() *peer.Peer {
 	return &peer.Peer{
-		Addr:     t.remoteAddr,
-		AuthInfo: t.authInfo, // Can be nil
+		Addr:      t.peer.Addr,
+		LocalAddr: t.peer.LocalAddr,
+		AuthInfo:  t.peer.AuthInfo, // Can be nil
 	}
 }

@@ -1461,6 +1433,6 @@ func GetConnection(ctx context.Context) net.Conn {
 // SetConnection adds the connection to the context to be able to get
 // information about the destination ip and port for an incoming RPC. This also
 // allows any unary or streaming interceptors to see the connection.
-func setConnection(ctx context.Context, conn net.Conn) context.Context {
+func SetConnection(ctx context.Context, conn net.Conn) context.Context {
 	return context.WithValue(ctx, connectionKey{}, conn)
 }
@@ -37,6 +37,7 @@ import (
 	"google.golang.org/grpc/internal/channelz"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/peer"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/stats"
 	"google.golang.org/grpc/status"

@@ -265,7 +266,8 @@ type Stream struct {
 	// headerValid indicates whether a valid header was received. Only
 	// meaningful after headerChan is closed (always call waitOnHeader() before
 	// reading its value). Not valid on server side.
-	headerValid bool
+	headerValid      bool
+	headerWireLength int // Only set on server side.

 	// hdrMu protects header and trailer metadata on the server-side.
 	hdrMu sync.Mutex

@@ -425,6 +427,12 @@ func (s *Stream) Context() context.Context {
 	return s.ctx
 }

+// SetContext sets the context of the stream. This will be deleted once the
+// stats handler callouts all move to gRPC layer.
+func (s *Stream) SetContext(ctx context.Context) {
+	s.ctx = ctx
+}
+
 // Method returns the method for the stream.
 func (s *Stream) Method() string {
 	return s.method

@@ -437,6 +445,12 @@ func (s *Stream) Status() *status.Status {
 	return s.status
 }

+// HeaderWireLength returns the size of the headers of the stream as received
+// from the wire. Valid only on the server.
+func (s *Stream) HeaderWireLength() int {
+	return s.headerWireLength
+}
+
 // SetHeader sets the header metadata. This can be called multiple times.
 // Server side only.
 // This should not be called in parallel to other data writes.

@@ -698,7 +712,7 @@ type ClientTransport interface {
 // Write methods for a given Stream will be called serially.
 type ServerTransport interface {
 	// HandleStreams receives incoming streams using the given handler.
-	HandleStreams(func(*Stream))
+	HandleStreams(context.Context, func(*Stream))

 	// WriteHeader sends the header metadata for the given stream.
 	// WriteHeader may not be called on all streams.

@@ -717,8 +731,8 @@ type ServerTransport interface {
 	// handlers will be terminated asynchronously.
 	Close(err error)

-	// RemoteAddr returns the remote network address.
-	RemoteAddr() net.Addr
+	// Peer returns the peer of the server transport.
+	Peer() *peer.Peer

 	// Drain notifies the client this ServerTransport stops accepting new RPCs.
 	Drain(debugData string)
@@ -35,8 +35,6 @@ import (
 	"testing"
 	"time"

-	"google.golang.org/grpc/peer"
-
 	"github.com/google/go-cmp/cmp"
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/hpack"

@@ -356,19 +354,19 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 	s.mu.Unlock()
 	switch ht {
 	case notifyCall:
-		go transport.HandleStreams(h.handleStreamAndNotify)
+		go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
 	case suspended:
-		go transport.HandleStreams(func(*Stream) {})
+		go transport.HandleStreams(context.Background(), func(*Stream) {})
 	case misbehaved:
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStreamMisbehave(t, s)
 		})
 	case encodingRequiredStatus:
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStreamEncodingRequiredStatus(s)
 		})
 	case invalidHeaderField:
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStreamInvalidHeaderField(s)
 		})
 	case delayRead:

@@ -377,15 +375,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 		s.mu.Lock()
 		close(s.ready)
 		s.mu.Unlock()
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStreamDelayRead(t, s)
 		})
 	case pingpong:
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStreamPingPong(t, s)
 		})
 	default:
-		go transport.HandleStreams(func(s *Stream) {
+		go transport.HandleStreams(context.Background(), func(s *Stream) {
 			go h.handleStream(t, s)
 		})
 	}

@@ -2594,52 +2592,3 @@ func TestConnectionError_Unwrap(t *testing.T) {
 		t.Error("ConnectionError does not unwrap")
 	}
 }
-
-func (s) TestPeerSetInServerContext(t *testing.T) {
-	// create client and server transports.
-	server, client, cancel := setUp(t, 0, normal)
-	defer cancel()
-	defer server.stop()
-	defer client.Close(fmt.Errorf("closed manually by test"))
-
-	// create a stream with client transport.
-	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-	defer cancel()
-	stream, err := client.NewStream(ctx, &CallHdr{})
-	if err != nil {
-		t.Fatalf("failed to create a stream: %v", err)
-	}
-
-	waitWhileTrue(t, func() (bool, error) {
-		server.mu.Lock()
-		defer server.mu.Unlock()
-
-		if len(server.conns) == 0 {
-			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
-		}
-		return false, nil
-	})
-
-	// verify peer is set in client transport context.
-	if _, ok := peer.FromContext(client.ctx); !ok {
-		t.Fatalf("Peer expected in client transport's context, but actually not found.")
-	}
-
-	// verify peer is set in stream context.
-	if _, ok := peer.FromContext(stream.ctx); !ok {
-		t.Fatalf("Peer expected in stream context, but actually not found.")
-	}
-
-	// verify peer is set in server transport context.
-	server.mu.Lock()
-	for k := range server.conns {
-		sc, ok := k.(*http2Server)
-		if !ok {
-			t.Fatalf("ServerTransport is of type %T, want %T", k, &http2Server{})
-		}
-		if _, ok = peer.FromContext(sc.ctx); !ok {
-			t.Fatalf("Peer expected in server transport's context, but actually not found.")
-		}
-	}
-	server.mu.Unlock()
-}
@@ -28,7 +28,6 @@ import (
 	"net"
 	"strconv"

-	v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/authz/audit"
 	"google.golang.org/grpc/codes"

@@ -38,6 +37,8 @@ import (
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/peer"
 	"google.golang.org/grpc/status"
+
+	v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
 )

 var logger = grpclog.Component("rbac")
@@ -32,6 +32,8 @@ import (
 type Peer struct {
 	// Addr is the peer address.
 	Addr net.Addr
+	// LocalAddr is the local address.
+	LocalAddr net.Addr
 	// AuthInfo is the authentication information of the transport.
 	// It is nil if there is no transport security being used.
 	AuthInfo credentials.AuthInfo
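Because peer.Peer now records the local address, application code can tell which listener an RPC arrived on. Here is a hedged sketch using a unary interceptor; the interceptor name and the logging are illustrative and assume services registered elsewhere.

package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/peer"
)

// peerInterceptor logs the remote and local addresses of each unary RPC.
func peerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	if p, ok := peer.FromContext(ctx); ok {
		// p.LocalAddr is populated by the server transport as of this change.
		log.Printf("%s: %v -> %v", info.FullMethod, p.Addr, p.LocalAddr)
	}
	return handler(ctx, req)
}

func main() {
	_ = grpc.NewServer(grpc.ChainUnaryInterceptor(peerInterceptor))
}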
server.go
@@ -920,7 +920,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
 		return
 	}
 	go func() {
-		s.serveStreams(st)
+		s.serveStreams(context.Background(), st, rawConn)
 		s.removeConn(lisAddr, st)
 	}()
 }
@@ -974,12 +974,27 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
 	return st
 }

-func (s *Server) serveStreams(st transport.ServerTransport) {
-	defer st.Close(errors.New("finished serving streams for the server transport"))
-	var wg sync.WaitGroup
+func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
+	ctx = transport.SetConnection(ctx, rawConn)
+	ctx = peer.NewContext(ctx, st.Peer())
+	for _, sh := range s.opts.statsHandlers {
+		ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
+			RemoteAddr: st.Peer().Addr,
+			LocalAddr:  st.Peer().LocalAddr,
+		})
+		sh.HandleConn(ctx, &stats.ConnBegin{})
+	}

+	defer func() {
+		st.Close(errors.New("finished serving streams for the server transport"))
+		for _, sh := range s.opts.statsHandlers {
+			sh.HandleConn(ctx, &stats.ConnEnd{})
+		}
+	}()
+
+	var wg sync.WaitGroup
 	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
-	st.HandleStreams(func(stream *transport.Stream) {
+	st.HandleStreams(ctx, func(stream *transport.Stream) {
 		wg.Add(1)

 		streamQuota.acquire()
@@ -1043,7 +1058,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	defer s.removeConn(listenerAddressForServeHTTP, st)
-	s.serveStreams(st)
+	s.serveStreams(r.Context(), st, nil)
 }

 func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
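ServeHTTP now forwards r.Context() into serveStreams, so streams served through the Go http.Server inherit the request context. A rough sketch of that serving mode follows, assuming HTTP/2 over TLS; the address and certificate paths are placeholders.

package main

import (
	"net/http"

	"google.golang.org/grpc"
)

func main() {
	gs := grpc.NewServer() // register services on gs as usual
	// *grpc.Server implements http.Handler; requests served this way take the
	// ServeHTTP -> serveStreams(r.Context(), st, nil) path shown above.
	hs := &http.Server{Addr: ":8443", Handler: gs}
	// The net/http server only speaks HTTP/2 when TLS is enabled.
	_ = hs.ListenAndServeTLS("server.crt", "server.key")
}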
@@ -1700,7 +1715,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
 			tr: tr,
 			firstLine: firstLine{
 				client:     false,
-				remoteAddr: t.RemoteAddr(),
+				remoteAddr: t.Peer().Addr,
 			},
 		}
 		if dl, ok := ctx.Deadline(); ok {
@@ -1734,6 +1749,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
 	service := sm[:pos]
 	method := sm[pos+1:]

+	md, _ := metadata.FromIncomingContext(ctx)
+	for _, sh := range s.opts.statsHandlers {
+		ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
+		sh.HandleRPC(ctx, &stats.InHeader{
+			FullMethod:  stream.Method(),
+			RemoteAddr:  t.Peer().Addr,
+			LocalAddr:   t.Peer().LocalAddr,
+			Compression: stream.RecvCompress(),
+			WireLength:  stream.HeaderWireLength(),
+			Header:      md,
+		})
+	}
+	// To have calls in stream callouts work. Will delete once all stats handler
+	// calls come from the gRPC layer.
+	stream.SetContext(ctx)
+
 	srv, knownService := s.services[service]
 	if knownService {
 		if md, ok := srv.methods[method]; ok {