mirror of https://github.com/grpc/grpc-go.git
grpc: Move some stats handler calls to gRPC layer, and add local address to peer.Peer (#6716)
This commit is contained in:
parent
6e14274d00
commit
c76d75f4f9
|
@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
|
||||||
return nil, errors.New(msg)
|
return nil, errors.New(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var localAddr net.Addr
|
||||||
|
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
|
||||||
|
localAddr, _ = la.(net.Addr)
|
||||||
|
}
|
||||||
|
var authInfo credentials.AuthInfo
|
||||||
|
if r.TLS != nil {
|
||||||
|
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
|
||||||
|
}
|
||||||
|
p := peer.Peer{
|
||||||
|
Addr: strAddr(r.RemoteAddr),
|
||||||
|
LocalAddr: localAddr,
|
||||||
|
AuthInfo: authInfo,
|
||||||
|
}
|
||||||
st := &serverHandlerTransport{
|
st := &serverHandlerTransport{
|
||||||
rw: w,
|
rw: w,
|
||||||
req: r,
|
req: r,
|
||||||
closedCh: make(chan struct{}),
|
closedCh: make(chan struct{}),
|
||||||
writes: make(chan func()),
|
writes: make(chan func()),
|
||||||
|
peer: p,
|
||||||
contentType: contentType,
|
contentType: contentType,
|
||||||
contentSubtype: contentSubtype,
|
contentSubtype: contentSubtype,
|
||||||
stats: stats,
|
stats: stats,
|
||||||
|
@ -134,6 +148,8 @@ type serverHandlerTransport struct {
|
||||||
|
|
||||||
headerMD metadata.MD
|
headerMD metadata.MD
|
||||||
|
|
||||||
|
peer peer.Peer
|
||||||
|
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closedCh chan struct{} // closed on Close
|
closedCh chan struct{} // closed on Close
|
||||||
|
|
||||||
|
@ -165,7 +181,13 @@ func (ht *serverHandlerTransport) Close(err error) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
|
func (ht *serverHandlerTransport) Peer() *peer.Peer {
|
||||||
|
return &peer.Peer{
|
||||||
|
Addr: ht.peer.Addr,
|
||||||
|
LocalAddr: ht.peer.LocalAddr,
|
||||||
|
AuthInfo: ht.peer.AuthInfo,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
|
// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
|
||||||
// the empty string if unknown.
|
// the empty string if unknown.
|
||||||
|
@ -347,10 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
|
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
|
||||||
// With this transport type there will be exactly 1 stream: this HTTP request.
|
// With this transport type there will be exactly 1 stream: this HTTP request.
|
||||||
|
|
||||||
ctx := ht.req.Context()
|
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
if ht.timeoutSet {
|
if ht.timeoutSet {
|
||||||
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
|
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
|
||||||
|
@ -370,34 +390,19 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
|
||||||
ht.Close(errors.New("request is done processing"))
|
ht.Close(errors.New("request is done processing"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req := ht.req
|
|
||||||
|
|
||||||
s := &Stream{
|
|
||||||
id: 0, // irrelevant
|
|
||||||
requestRead: func(int) {},
|
|
||||||
cancel: cancel,
|
|
||||||
buf: newRecvBuffer(),
|
|
||||||
st: ht,
|
|
||||||
method: req.URL.Path,
|
|
||||||
recvCompress: req.Header.Get("grpc-encoding"),
|
|
||||||
contentSubtype: ht.contentSubtype,
|
|
||||||
}
|
|
||||||
pr := &peer.Peer{
|
|
||||||
Addr: ht.RemoteAddr(),
|
|
||||||
}
|
|
||||||
if req.TLS != nil {
|
|
||||||
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
|
|
||||||
}
|
|
||||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||||
s.ctx = peer.NewContext(ctx, pr)
|
req := ht.req
|
||||||
for _, sh := range ht.stats {
|
s := &Stream{
|
||||||
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
id: 0, // irrelevant
|
||||||
inHeader := &stats.InHeader{
|
ctx: ctx,
|
||||||
FullMethod: s.method,
|
requestRead: func(int) {},
|
||||||
RemoteAddr: ht.RemoteAddr(),
|
cancel: cancel,
|
||||||
Compression: s.recvCompress,
|
buf: newRecvBuffer(),
|
||||||
}
|
st: ht,
|
||||||
sh.HandleRPC(s.ctx, inHeader)
|
method: req.URL.Path,
|
||||||
|
recvCompress: req.Header.Get("grpc-encoding"),
|
||||||
|
contentSubtype: ht.contentSubtype,
|
||||||
|
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
|
||||||
}
|
}
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
|
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
|
||||||
|
|
|
@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
|
||||||
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
||||||
}
|
}
|
||||||
st.ht.HandleStreams(
|
st.ht.HandleStreams(
|
||||||
func(s *Stream) { go handleStream(s) },
|
context.Background(), func(s *Stream) { go handleStream(s) },
|
||||||
)
|
)
|
||||||
wantHeader := http.Header{
|
wantHeader := http.Header{
|
||||||
"Date": nil,
|
"Date": nil,
|
||||||
|
@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
|
||||||
st.ht.WriteStatus(s, status.New(statusCode, msg))
|
st.ht.WriteStatus(s, status.New(statusCode, msg))
|
||||||
}
|
}
|
||||||
st.ht.HandleStreams(
|
st.ht.HandleStreams(
|
||||||
func(s *Stream) { go handleStream(s) },
|
context.Background(), func(s *Stream) { go handleStream(s) },
|
||||||
)
|
)
|
||||||
wantHeader := http.Header{
|
wantHeader := http.Header{
|
||||||
"Date": nil,
|
"Date": nil,
|
||||||
|
@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
||||||
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
|
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
|
||||||
}
|
}
|
||||||
ht.HandleStreams(
|
ht.HandleStreams(
|
||||||
func(s *Stream) { go runStream(s) },
|
context.Background(), func(s *Stream) { go runStream(s) },
|
||||||
)
|
)
|
||||||
wantHeader := http.Header{
|
wantHeader := http.Header{
|
||||||
"Date": nil,
|
"Date": nil,
|
||||||
|
@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
|
||||||
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
|
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
|
||||||
st := newHandleStreamTest(t)
|
st := newHandleStreamTest(t)
|
||||||
st.ht.HandleStreams(
|
st.ht.HandleStreams(
|
||||||
func(s *Stream) { go handleStream(st, s) },
|
context.Background(), func(s *Stream) { go handleStream(st, s) },
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
|
||||||
hst.ht.WriteStatus(s, st)
|
hst.ht.WriteStatus(s, st)
|
||||||
}
|
}
|
||||||
hst.ht.HandleStreams(
|
hst.ht.HandleStreams(
|
||||||
func(s *Stream) { go handleStream(s) },
|
context.Background(), func(s *Stream) { go handleStream(s) },
|
||||||
)
|
)
|
||||||
wantHeader := http.Header{
|
wantHeader := http.Header{
|
||||||
"Date": nil,
|
"Date": nil,
|
||||||
|
|
|
@ -69,15 +69,12 @@ var serverConnectionCounter uint64
|
||||||
// http2Server implements the ServerTransport interface with HTTP2.
|
// http2Server implements the ServerTransport interface with HTTP2.
|
||||||
type http2Server struct {
|
type http2Server struct {
|
||||||
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
|
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
|
||||||
ctx context.Context
|
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
loopy *loopyWriter
|
loopy *loopyWriter
|
||||||
readerDone chan struct{} // sync point to enable testing.
|
readerDone chan struct{} // sync point to enable testing.
|
||||||
writerDone chan struct{} // sync point to enable testing.
|
writerDone chan struct{} // sync point to enable testing.
|
||||||
remoteAddr net.Addr
|
peer peer.Peer
|
||||||
localAddr net.Addr
|
|
||||||
authInfo credentials.AuthInfo // auth info about the connection
|
|
||||||
inTapHandle tap.ServerInHandle
|
inTapHandle tap.ServerInHandle
|
||||||
framer *framer
|
framer *framer
|
||||||
// The max number of concurrent streams.
|
// The max number of concurrent streams.
|
||||||
|
@ -243,13 +240,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
peer := peer.Peer{
|
||||||
|
Addr: conn.RemoteAddr(),
|
||||||
|
LocalAddr: conn.LocalAddr(),
|
||||||
|
AuthInfo: authInfo,
|
||||||
|
}
|
||||||
t := &http2Server{
|
t := &http2Server{
|
||||||
ctx: setConnection(context.Background(), rawConn),
|
|
||||||
done: done,
|
done: done,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
remoteAddr: conn.RemoteAddr(),
|
peer: peer,
|
||||||
localAddr: conn.LocalAddr(),
|
|
||||||
authInfo: authInfo,
|
|
||||||
framer: framer,
|
framer: framer,
|
||||||
readerDone: make(chan struct{}),
|
readerDone: make(chan struct{}),
|
||||||
writerDone: make(chan struct{}),
|
writerDone: make(chan struct{}),
|
||||||
|
@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
bufferPool: newBufferPool(),
|
bufferPool: newBufferPool(),
|
||||||
}
|
}
|
||||||
t.logger = prefixLoggerForServerTransport(t)
|
t.logger = prefixLoggerForServerTransport(t)
|
||||||
// Add peer information to the http2server context.
|
|
||||||
t.ctx = peer.NewContext(t.ctx, t.getPeer())
|
|
||||||
|
|
||||||
t.controlBuf = newControlBuffer(t.done)
|
t.controlBuf = newControlBuffer(t.done)
|
||||||
if dynamicWindow {
|
if dynamicWindow {
|
||||||
|
@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
updateFlowControl: t.updateFlowControl,
|
updateFlowControl: t.updateFlowControl,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, sh := range t.stats {
|
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
|
||||||
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
|
|
||||||
RemoteAddr: t.remoteAddr,
|
|
||||||
LocalAddr: t.localAddr,
|
|
||||||
})
|
|
||||||
connBegin := &stats.ConnBegin{}
|
|
||||||
sh.HandleConn(t.ctx, connBegin)
|
|
||||||
}
|
|
||||||
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -342,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
|
||||||
|
|
||||||
// operateHeaders takes action on the decoded headers. Returns an error if fatal
|
// operateHeaders takes action on the decoded headers. Returns an error if fatal
|
||||||
// error encountered and transport needs to close, otherwise returns nil.
|
// error encountered and transport needs to close, otherwise returns nil.
|
||||||
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
|
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
|
||||||
// Acquire max stream ID lock for entire duration
|
// Acquire max stream ID lock for entire duration
|
||||||
t.maxStreamMu.Lock()
|
t.maxStreamMu.Lock()
|
||||||
defer t.maxStreamMu.Unlock()
|
defer t.maxStreamMu.Unlock()
|
||||||
|
@ -369,10 +358,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
|
|
||||||
buf := newRecvBuffer()
|
buf := newRecvBuffer()
|
||||||
s := &Stream{
|
s := &Stream{
|
||||||
id: streamID,
|
id: streamID,
|
||||||
st: t,
|
st: t,
|
||||||
buf: buf,
|
buf: buf,
|
||||||
fc: &inFlow{limit: uint32(t.initialWindowSize)},
|
fc: &inFlow{limit: uint32(t.initialWindowSize)},
|
||||||
|
headerWireLength: int(frame.Header().Length),
|
||||||
}
|
}
|
||||||
var (
|
var (
|
||||||
// if false, content-type was missing or invalid
|
// if false, content-type was missing or invalid
|
||||||
|
@ -511,9 +501,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
s.state = streamReadDone
|
s.state = streamReadDone
|
||||||
}
|
}
|
||||||
if timeoutSet {
|
if timeoutSet {
|
||||||
s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout)
|
s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
|
||||||
} else {
|
} else {
|
||||||
s.ctx, s.cancel = context.WithCancel(t.ctx)
|
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attach the received metadata to the context.
|
// Attach the received metadata to the context.
|
||||||
|
@ -592,18 +582,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
s.requestRead = func(n int) {
|
s.requestRead = func(n int) {
|
||||||
t.adjustWindow(s, uint32(n))
|
t.adjustWindow(s, uint32(n))
|
||||||
}
|
}
|
||||||
for _, sh := range t.stats {
|
|
||||||
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
|
||||||
inHeader := &stats.InHeader{
|
|
||||||
FullMethod: s.method,
|
|
||||||
RemoteAddr: t.remoteAddr,
|
|
||||||
LocalAddr: t.localAddr,
|
|
||||||
Compression: s.recvCompress,
|
|
||||||
WireLength: int(frame.Header().Length),
|
|
||||||
Header: mdata.Copy(),
|
|
||||||
}
|
|
||||||
sh.HandleRPC(s.ctx, inHeader)
|
|
||||||
}
|
|
||||||
s.ctxDone = s.ctx.Done()
|
s.ctxDone = s.ctx.Done()
|
||||||
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
|
||||||
s.trReader = &transportReader{
|
s.trReader = &transportReader{
|
||||||
|
@ -629,7 +607,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
||||||
// HandleStreams receives incoming streams using the given handler. This is
|
// HandleStreams receives incoming streams using the given handler. This is
|
||||||
// typically run in a separate goroutine.
|
// typically run in a separate goroutine.
|
||||||
// traceCtx attaches trace to ctx and returns the new context.
|
// traceCtx attaches trace to ctx and returns the new context.
|
||||||
func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
|
||||||
defer close(t.readerDone)
|
defer close(t.readerDone)
|
||||||
for {
|
for {
|
||||||
t.controlBuf.throttle()
|
t.controlBuf.throttle()
|
||||||
|
@ -664,7 +642,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
|
||||||
}
|
}
|
||||||
switch frame := frame.(type) {
|
switch frame := frame.(type) {
|
||||||
case *http2.MetaHeadersFrame:
|
case *http2.MetaHeadersFrame:
|
||||||
if err := t.operateHeaders(frame, handle); err != nil {
|
if err := t.operateHeaders(ctx, frame, handle); err != nil {
|
||||||
t.Close(err)
|
t.Close(err)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -1242,10 +1220,6 @@ func (t *http2Server) Close(err error) {
|
||||||
for _, s := range streams {
|
for _, s := range streams {
|
||||||
s.cancel()
|
s.cancel()
|
||||||
}
|
}
|
||||||
for _, sh := range t.stats {
|
|
||||||
connEnd := &stats.ConnEnd{}
|
|
||||||
sh.HandleConn(t.ctx, connEnd)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteStream deletes the stream s from transport's active streams.
|
// deleteStream deletes the stream s from transport's active streams.
|
||||||
|
@ -1311,10 +1285,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Server) RemoteAddr() net.Addr {
|
|
||||||
return t.remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *http2Server) Drain(debugData string) {
|
func (t *http2Server) Drain(debugData string) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
|
@ -1397,11 +1367,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
|
||||||
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
|
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
|
||||||
LocalFlowControlWindow: int64(t.fc.getSize()),
|
LocalFlowControlWindow: int64(t.fc.getSize()),
|
||||||
SocketOptions: channelz.GetSocketOption(t.conn),
|
SocketOptions: channelz.GetSocketOption(t.conn),
|
||||||
LocalAddr: t.localAddr,
|
LocalAddr: t.peer.LocalAddr,
|
||||||
RemoteAddr: t.remoteAddr,
|
RemoteAddr: t.peer.Addr,
|
||||||
// RemoteName :
|
// RemoteName :
|
||||||
}
|
}
|
||||||
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
|
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
|
||||||
s.Security = au.GetSecurityValue()
|
s.Security = au.GetSecurityValue()
|
||||||
}
|
}
|
||||||
s.RemoteFlowControlWindow = t.getOutFlowWindow()
|
s.RemoteFlowControlWindow = t.getOutFlowWindow()
|
||||||
|
@ -1433,10 +1403,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *http2Server) getPeer() *peer.Peer {
|
// Peer returns the peer of the transport.
|
||||||
|
func (t *http2Server) Peer() *peer.Peer {
|
||||||
return &peer.Peer{
|
return &peer.Peer{
|
||||||
Addr: t.remoteAddr,
|
Addr: t.peer.Addr,
|
||||||
AuthInfo: t.authInfo, // Can be nil
|
LocalAddr: t.peer.LocalAddr,
|
||||||
|
AuthInfo: t.peer.AuthInfo, // Can be nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1461,6 +1433,6 @@ func GetConnection(ctx context.Context) net.Conn {
|
||||||
// SetConnection adds the connection to the context to be able to get
|
// SetConnection adds the connection to the context to be able to get
|
||||||
// information about the destination ip and port for an incoming RPC. This also
|
// information about the destination ip and port for an incoming RPC. This also
|
||||||
// allows any unary or streaming interceptors to see the connection.
|
// allows any unary or streaming interceptors to see the connection.
|
||||||
func setConnection(ctx context.Context, conn net.Conn) context.Context {
|
func SetConnection(ctx context.Context, conn net.Conn) context.Context {
|
||||||
return context.WithValue(ctx, connectionKey{}, conn)
|
return context.WithValue(ctx, connectionKey{}, conn)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,7 @@ import (
|
||||||
"google.golang.org/grpc/internal/channelz"
|
"google.golang.org/grpc/internal/channelz"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/resolver"
|
"google.golang.org/grpc/resolver"
|
||||||
"google.golang.org/grpc/stats"
|
"google.golang.org/grpc/stats"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
@ -265,7 +266,8 @@ type Stream struct {
|
||||||
// headerValid indicates whether a valid header was received. Only
|
// headerValid indicates whether a valid header was received. Only
|
||||||
// meaningful after headerChan is closed (always call waitOnHeader() before
|
// meaningful after headerChan is closed (always call waitOnHeader() before
|
||||||
// reading its value). Not valid on server side.
|
// reading its value). Not valid on server side.
|
||||||
headerValid bool
|
headerValid bool
|
||||||
|
headerWireLength int // Only set on server side.
|
||||||
|
|
||||||
// hdrMu protects header and trailer metadata on the server-side.
|
// hdrMu protects header and trailer metadata on the server-side.
|
||||||
hdrMu sync.Mutex
|
hdrMu sync.Mutex
|
||||||
|
@ -425,6 +427,12 @@ func (s *Stream) Context() context.Context {
|
||||||
return s.ctx
|
return s.ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetContext sets the context of the stream. This will be deleted once the
|
||||||
|
// stats handler callouts all move to gRPC layer.
|
||||||
|
func (s *Stream) SetContext(ctx context.Context) {
|
||||||
|
s.ctx = ctx
|
||||||
|
}
|
||||||
|
|
||||||
// Method returns the method for the stream.
|
// Method returns the method for the stream.
|
||||||
func (s *Stream) Method() string {
|
func (s *Stream) Method() string {
|
||||||
return s.method
|
return s.method
|
||||||
|
@ -437,6 +445,12 @@ func (s *Stream) Status() *status.Status {
|
||||||
return s.status
|
return s.status
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HeaderWireLength returns the size of the headers of the stream as received
|
||||||
|
// from the wire. Valid only on the server.
|
||||||
|
func (s *Stream) HeaderWireLength() int {
|
||||||
|
return s.headerWireLength
|
||||||
|
}
|
||||||
|
|
||||||
// SetHeader sets the header metadata. This can be called multiple times.
|
// SetHeader sets the header metadata. This can be called multiple times.
|
||||||
// Server side only.
|
// Server side only.
|
||||||
// This should not be called in parallel to other data writes.
|
// This should not be called in parallel to other data writes.
|
||||||
|
@ -698,7 +712,7 @@ type ClientTransport interface {
|
||||||
// Write methods for a given Stream will be called serially.
|
// Write methods for a given Stream will be called serially.
|
||||||
type ServerTransport interface {
|
type ServerTransport interface {
|
||||||
// HandleStreams receives incoming streams using the given handler.
|
// HandleStreams receives incoming streams using the given handler.
|
||||||
HandleStreams(func(*Stream))
|
HandleStreams(context.Context, func(*Stream))
|
||||||
|
|
||||||
// WriteHeader sends the header metadata for the given stream.
|
// WriteHeader sends the header metadata for the given stream.
|
||||||
// WriteHeader may not be called on all streams.
|
// WriteHeader may not be called on all streams.
|
||||||
|
@ -717,8 +731,8 @@ type ServerTransport interface {
|
||||||
// handlers will be terminated asynchronously.
|
// handlers will be terminated asynchronously.
|
||||||
Close(err error)
|
Close(err error)
|
||||||
|
|
||||||
// RemoteAddr returns the remote network address.
|
// Peer returns the peer of the server transport.
|
||||||
RemoteAddr() net.Addr
|
Peer() *peer.Peer
|
||||||
|
|
||||||
// Drain notifies the client this ServerTransport stops accepting new RPCs.
|
// Drain notifies the client this ServerTransport stops accepting new RPCs.
|
||||||
Drain(debugData string)
|
Drain(debugData string)
|
||||||
|
|
|
@ -35,8 +35,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"google.golang.org/grpc/peer"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/hpack"
|
"golang.org/x/net/http2/hpack"
|
||||||
|
@ -356,19 +354,19 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
switch ht {
|
switch ht {
|
||||||
case notifyCall:
|
case notifyCall:
|
||||||
go transport.HandleStreams(h.handleStreamAndNotify)
|
go transport.HandleStreams(context.Background(), h.handleStreamAndNotify)
|
||||||
case suspended:
|
case suspended:
|
||||||
go transport.HandleStreams(func(*Stream) {})
|
go transport.HandleStreams(context.Background(), func(*Stream) {})
|
||||||
case misbehaved:
|
case misbehaved:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStreamMisbehave(t, s)
|
go h.handleStreamMisbehave(t, s)
|
||||||
})
|
})
|
||||||
case encodingRequiredStatus:
|
case encodingRequiredStatus:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStreamEncodingRequiredStatus(s)
|
go h.handleStreamEncodingRequiredStatus(s)
|
||||||
})
|
})
|
||||||
case invalidHeaderField:
|
case invalidHeaderField:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStreamInvalidHeaderField(s)
|
go h.handleStreamInvalidHeaderField(s)
|
||||||
})
|
})
|
||||||
case delayRead:
|
case delayRead:
|
||||||
|
@ -377,15 +375,15 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
close(s.ready)
|
close(s.ready)
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStreamDelayRead(t, s)
|
go h.handleStreamDelayRead(t, s)
|
||||||
})
|
})
|
||||||
case pingpong:
|
case pingpong:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStreamPingPong(t, s)
|
go h.handleStreamPingPong(t, s)
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
go transport.HandleStreams(func(s *Stream) {
|
go transport.HandleStreams(context.Background(), func(s *Stream) {
|
||||||
go h.handleStream(t, s)
|
go h.handleStream(t, s)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -2594,52 +2592,3 @@ func TestConnectionError_Unwrap(t *testing.T) {
|
||||||
t.Error("ConnectionError does not unwrap")
|
t.Error("ConnectionError does not unwrap")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s) TestPeerSetInServerContext(t *testing.T) {
|
|
||||||
// create client and server transports.
|
|
||||||
server, client, cancel := setUp(t, 0, normal)
|
|
||||||
defer cancel()
|
|
||||||
defer server.stop()
|
|
||||||
defer client.Close(fmt.Errorf("closed manually by test"))
|
|
||||||
|
|
||||||
// create a stream with client transport.
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
||||||
defer cancel()
|
|
||||||
stream, err := client.NewStream(ctx, &CallHdr{})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create a stream: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
waitWhileTrue(t, func() (bool, error) {
|
|
||||||
server.mu.Lock()
|
|
||||||
defer server.mu.Unlock()
|
|
||||||
|
|
||||||
if len(server.conns) == 0 {
|
|
||||||
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// verify peer is set in client transport context.
|
|
||||||
if _, ok := peer.FromContext(client.ctx); !ok {
|
|
||||||
t.Fatalf("Peer expected in client transport's context, but actually not found.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify peer is set in stream context.
|
|
||||||
if _, ok := peer.FromContext(stream.ctx); !ok {
|
|
||||||
t.Fatalf("Peer expected in stream context, but actually not found.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify peer is set in server transport context.
|
|
||||||
server.mu.Lock()
|
|
||||||
for k := range server.conns {
|
|
||||||
sc, ok := k.(*http2Server)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("ServerTransport is of type %T, want %T", k, &http2Server{})
|
|
||||||
}
|
|
||||||
if _, ok = peer.FromContext(sc.ctx); !ok {
|
|
||||||
t.Fatalf("Peer expected in server transport's context, but actually not found.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
server.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
|
@ -28,7 +28,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/authz/audit"
|
"google.golang.org/grpc/authz/audit"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
@ -38,6 +37,8 @@ import (
|
||||||
"google.golang.org/grpc/metadata"
|
"google.golang.org/grpc/metadata"
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = grpclog.Component("rbac")
|
var logger = grpclog.Component("rbac")
|
||||||
|
|
|
@ -32,6 +32,8 @@ import (
|
||||||
type Peer struct {
|
type Peer struct {
|
||||||
// Addr is the peer address.
|
// Addr is the peer address.
|
||||||
Addr net.Addr
|
Addr net.Addr
|
||||||
|
// LocalAddr is the local address.
|
||||||
|
LocalAddr net.Addr
|
||||||
// AuthInfo is the authentication information of the transport.
|
// AuthInfo is the authentication information of the transport.
|
||||||
// It is nil if there is no transport security being used.
|
// It is nil if there is no transport security being used.
|
||||||
AuthInfo credentials.AuthInfo
|
AuthInfo credentials.AuthInfo
|
||||||
|
|
45
server.go
45
server.go
|
@ -920,7 +920,7 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
s.serveStreams(st)
|
s.serveStreams(context.Background(), st, rawConn)
|
||||||
s.removeConn(lisAddr, st)
|
s.removeConn(lisAddr, st)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -974,12 +974,27 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
|
||||||
return st
|
return st
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) serveStreams(st transport.ServerTransport) {
|
func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) {
|
||||||
defer st.Close(errors.New("finished serving streams for the server transport"))
|
ctx = transport.SetConnection(ctx, rawConn)
|
||||||
var wg sync.WaitGroup
|
ctx = peer.NewContext(ctx, st.Peer())
|
||||||
|
for _, sh := range s.opts.statsHandlers {
|
||||||
|
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
|
||||||
|
RemoteAddr: st.Peer().Addr,
|
||||||
|
LocalAddr: st.Peer().LocalAddr,
|
||||||
|
})
|
||||||
|
sh.HandleConn(ctx, &stats.ConnBegin{})
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
st.Close(errors.New("finished serving streams for the server transport"))
|
||||||
|
for _, sh := range s.opts.statsHandlers {
|
||||||
|
sh.HandleConn(ctx, &stats.ConnEnd{})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
|
streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
|
||||||
st.HandleStreams(func(stream *transport.Stream) {
|
st.HandleStreams(ctx, func(stream *transport.Stream) {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
streamQuota.acquire()
|
streamQuota.acquire()
|
||||||
|
@ -1043,7 +1058,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer s.removeConn(listenerAddressForServeHTTP, st)
|
defer s.removeConn(listenerAddressForServeHTTP, st)
|
||||||
s.serveStreams(st)
|
s.serveStreams(r.Context(), st, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
|
func (s *Server) addConn(addr string, st transport.ServerTransport) bool {
|
||||||
|
@ -1700,7 +1715,7 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
||||||
tr: tr,
|
tr: tr,
|
||||||
firstLine: firstLine{
|
firstLine: firstLine{
|
||||||
client: false,
|
client: false,
|
||||||
remoteAddr: t.RemoteAddr(),
|
remoteAddr: t.Peer().Addr,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if dl, ok := ctx.Deadline(); ok {
|
if dl, ok := ctx.Deadline(); ok {
|
||||||
|
@ -1734,6 +1749,22 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
||||||
service := sm[:pos]
|
service := sm[:pos]
|
||||||
method := sm[pos+1:]
|
method := sm[pos+1:]
|
||||||
|
|
||||||
|
md, _ := metadata.FromIncomingContext(ctx)
|
||||||
|
for _, sh := range s.opts.statsHandlers {
|
||||||
|
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
|
||||||
|
sh.HandleRPC(ctx, &stats.InHeader{
|
||||||
|
FullMethod: stream.Method(),
|
||||||
|
RemoteAddr: t.Peer().Addr,
|
||||||
|
LocalAddr: t.Peer().LocalAddr,
|
||||||
|
Compression: stream.RecvCompress(),
|
||||||
|
WireLength: stream.HeaderWireLength(),
|
||||||
|
Header: md,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// To have calls in stream callouts work. Will delete once all stats handler
|
||||||
|
// calls come from the gRPC layer.
|
||||||
|
stream.SetContext(ctx)
|
||||||
|
|
||||||
srv, knownService := s.services[service]
|
srv, knownService := s.services[service]
|
||||||
if knownService {
|
if knownService {
|
||||||
if md, ok := srv.methods[method]; ok {
|
if md, ok := srv.methods[method]; ok {
|
||||||
|
|
Loading…
Reference in New Issue