mirror of https://github.com/grpc/grpc-go.git
Allow storing alternate transport.ServerStream implementations in context (#1904)
This commit is contained in:
parent
031ee13cfe
commit
57640c0e6f
58
server.go
58
server.go
|
@ -919,7 +919,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
return nil
|
||||
}
|
||||
reply, appErr := md.Handler(srv.server, stream.Context(), df, s.opts.unaryInt)
|
||||
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||
reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
|
||||
if appErr != nil {
|
||||
appStatus, ok := status.FromError(appErr)
|
||||
if !ok {
|
||||
|
@ -995,7 +996,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
sh.HandleRPC(stream.Context(), end)
|
||||
}()
|
||||
}
|
||||
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
|
||||
ss := &serverStream{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
s: stream,
|
||||
p: &parser{r: stream},
|
||||
|
@ -1089,7 +1092,6 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
ss.mu.Unlock()
|
||||
}
|
||||
return t.WriteStatus(ss.s, status.New(codes.OK, ""))
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
|
||||
|
@ -1171,6 +1173,40 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
|
|||
}
|
||||
}
|
||||
|
||||
// The key to save ServerTransportStream in the context.
|
||||
type streamKey struct{}
|
||||
|
||||
// NewContextWithServerTransportStream creates a new context from ctx and
|
||||
// attaches stream to it.
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
func NewContextWithServerTransportStream(ctx context.Context, stream ServerTransportStream) context.Context {
|
||||
return context.WithValue(ctx, streamKey{}, stream)
|
||||
}
|
||||
|
||||
// ServerTransportStream is a minimal interface that a transport stream must
|
||||
// implement. This can be used to mock an actual transport stream for tests of
|
||||
// handler code that use, for example, grpc.SetHeader (which requires some
|
||||
// stream to be in context).
|
||||
//
|
||||
// See also NewContextWithServerTransportStream.
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
type ServerTransportStream interface {
|
||||
Method() string
|
||||
SetHeader(md metadata.MD) error
|
||||
SendHeader(md metadata.MD) error
|
||||
SetTrailer(md metadata.MD) error
|
||||
}
|
||||
|
||||
// serverStreamFromContext returns the server stream saved in ctx. Returns
|
||||
// nil if the given context has no stream associated with it (which implies
|
||||
// it is not an RPC invocation context).
|
||||
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
||||
return s
|
||||
}
|
||||
|
||||
// Stop stops the gRPC server. It immediately closes all open
|
||||
// connections and listeners.
|
||||
// It cancels all active RPCs on the server side and the corresponding
|
||||
|
@ -1291,8 +1327,8 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
return stream.SetHeader(md)
|
||||
|
@ -1301,15 +1337,11 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||
// SendHeader sends header metadata. It may be called at most once.
|
||||
// The provided md and headers set by SetHeader() will be sent.
|
||||
func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
t := stream.ServerTransport()
|
||||
if t == nil {
|
||||
grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream)
|
||||
}
|
||||
if err := t.WriteHeader(stream, md); err != nil {
|
||||
if err := stream.SendHeader(md); err != nil {
|
||||
return toRPCErr(err)
|
||||
}
|
||||
return nil
|
||||
|
@ -1321,8 +1353,8 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
|||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream, ok := transport.StreamFromContext(ctx)
|
||||
if !ok {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
return stream.SetTrailer(md)
|
||||
|
|
|
@ -25,7 +25,9 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc/test/leakcheck"
|
||||
"google.golang.org/grpc/transport"
|
||||
)
|
||||
|
||||
type emptyServiceServer interface{}
|
||||
|
@ -122,3 +124,13 @@ func TestGetServiceInfo(t *testing.T) {
|
|||
t.Errorf("GetServiceInfo() = %+v, want %+v", info, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamContext(t *testing.T) {
|
||||
expectedStream := &transport.Stream{}
|
||||
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
||||
s := serverTransportStreamFromContext(ctx)
|
||||
stream, ok := s.(*transport.Stream)
|
||||
if !ok || expectedStream != stream {
|
||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
|
||||
}
|
||||
}
|
||||
|
|
11
stream.go
11
stream.go
|
@ -608,6 +608,7 @@ type ServerStream interface {
|
|||
|
||||
// serverStream implements a server side Stream.
|
||||
type serverStream struct {
|
||||
ctx context.Context
|
||||
t transport.ServerTransport
|
||||
s *transport.Stream
|
||||
p *parser
|
||||
|
@ -628,7 +629,7 @@ type serverStream struct {
|
|||
}
|
||||
|
||||
func (ss *serverStream) Context() context.Context {
|
||||
return ss.s.Context()
|
||||
return ss.ctx
|
||||
}
|
||||
|
||||
func (ss *serverStream) SetHeader(md metadata.MD) error {
|
||||
|
@ -731,9 +732,9 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
|
|||
// MethodFromServerStream returns the method string for the input stream.
|
||||
// The returned string is in the format of "/service/method".
|
||||
func MethodFromServerStream(stream ServerStream) (string, bool) {
|
||||
s, ok := transport.StreamFromContext(stream.Context())
|
||||
if !ok {
|
||||
return "", ok
|
||||
s := serverTransportStreamFromContext(stream.Context())
|
||||
if s == nil {
|
||||
return "", false
|
||||
}
|
||||
return s.Method(), ok
|
||||
return s.Method(), true
|
||||
}
|
||||
|
|
|
@ -354,8 +354,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace
|
|||
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS}
|
||||
}
|
||||
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
|
||||
ctx = peer.NewContext(ctx, pr)
|
||||
s.ctx = newContextWithStream(ctx, s)
|
||||
s.ctx = peer.NewContext(ctx, pr)
|
||||
if ht.stats != nil {
|
||||
s.ctx = ht.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
|
||||
inHeader := &stats.InHeader{
|
||||
|
|
|
@ -307,10 +307,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
pr.AuthInfo = t.authInfo
|
||||
}
|
||||
s.ctx = peer.NewContext(s.ctx, pr)
|
||||
// Cache the current stream to the context so that the server application
|
||||
// can find out. Required when the server wants to send some metadata
|
||||
// back to the client (unary call only).
|
||||
s.ctx = newContextWithStream(s.ctx, s)
|
||||
// Attach the received metadata to the context.
|
||||
if len(state.mdata) > 0 {
|
||||
s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata)
|
||||
|
|
|
@ -366,6 +366,14 @@ func (s *Stream) SetHeader(md metadata.MD) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SendHeader sends the given header metadata. The given metadata is
|
||||
// combined with any metadata set by previous calls to SetHeader and
|
||||
// then written to the transport stream.
|
||||
func (s *Stream) SendHeader(md metadata.MD) error {
|
||||
t := s.ServerTransport()
|
||||
return t.WriteHeader(s, md)
|
||||
}
|
||||
|
||||
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
||||
// by the server. This can be called multiple times. Server side only.
|
||||
func (s *Stream) SetTrailer(md metadata.MD) error {
|
||||
|
@ -445,21 +453,6 @@ func (s *Stream) GoString() string {
|
|||
return fmt.Sprintf("<stream: %p, %v>", s, s.method)
|
||||
}
|
||||
|
||||
// The key to save transport.Stream in the context.
|
||||
type streamKey struct{}
|
||||
|
||||
// newContextWithStream creates a new context from ctx and attaches stream
|
||||
// to it.
|
||||
func newContextWithStream(ctx context.Context, stream *Stream) context.Context {
|
||||
return context.WithValue(ctx, streamKey{}, stream)
|
||||
}
|
||||
|
||||
// StreamFromContext returns the stream saved in ctx.
|
||||
func StreamFromContext(ctx context.Context) (s *Stream, ok bool) {
|
||||
s, ok = ctx.Value(streamKey{}).(*Stream)
|
||||
return
|
||||
}
|
||||
|
||||
// state of transport
|
||||
type transportState int
|
||||
|
||||
|
|
|
@ -1552,15 +1552,6 @@ func TestInvalidHeaderField(t *testing.T) {
|
|||
server.stop()
|
||||
}
|
||||
|
||||
func TestStreamContext(t *testing.T) {
|
||||
expectedStream := &Stream{}
|
||||
ctx := newContextWithStream(context.Background(), expectedStream)
|
||||
s, ok := StreamFromContext(ctx)
|
||||
if !ok || expectedStream != s {
|
||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsReservedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
h string
|
||||
|
|
Loading…
Reference in New Issue