server: export ServerTransportStreamFromContext for unary interceptors to control headers/trailers (#2019)

This commit is contained in:
dfawley 2018-04-26 17:38:15 -07:00 committed by Menghan Li
parent 07709e8a3d
commit fc37cf1364
2 changed files with 11 additions and 9 deletions

View File

@ -1298,10 +1298,12 @@ type ServerTransportStream interface {
SetTrailer(md metadata.MD) error SetTrailer(md metadata.MD) error
} }
// serverStreamFromContext returns the server stream saved in ctx. Returns // ServerTransportStreamFromContext returns the ServerTransportStream saved in
// nil if the given context has no stream associated with it (which implies // ctx. Returns nil if the given context has no stream associated with it
// it is not an RPC invocation context). // (which implies it is not an RPC invocation context).
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream { //
// This API is EXPERIMENTAL.
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
s, _ := ctx.Value(streamKey{}).(ServerTransportStream) s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
return s return s
} }
@ -1438,7 +1440,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 { if md.Len() == 0 {
return nil return nil
} }
stream := serverTransportStreamFromContext(ctx) stream := ServerTransportStreamFromContext(ctx)
if stream == nil { if stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
} }
@ -1448,7 +1450,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
// SendHeader sends header metadata. It may be called at most once. // SendHeader sends header metadata. It may be called at most once.
// The provided md and headers set by SetHeader() will be sent. // The provided md and headers set by SetHeader() will be sent.
func SendHeader(ctx context.Context, md metadata.MD) error { func SendHeader(ctx context.Context, md metadata.MD) error {
stream := serverTransportStreamFromContext(ctx) stream := ServerTransportStreamFromContext(ctx)
if stream == nil { if stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
} }
@ -1464,7 +1466,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
if md.Len() == 0 { if md.Len() == 0 {
return nil return nil
} }
stream := serverTransportStreamFromContext(ctx) stream := ServerTransportStreamFromContext(ctx)
if stream == nil { if stream == nil {
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx) return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
} }
@ -1474,7 +1476,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
// Method returns the method string for the server context. The returned // Method returns the method string for the server context. The returned
// string is in the format of "/service/method". // string is in the format of "/service/method".
func Method(ctx context.Context) (string, bool) { func Method(ctx context.Context) (string, bool) {
s := serverTransportStreamFromContext(ctx) s := ServerTransportStreamFromContext(ctx)
if s == nil { if s == nil {
return "", false return "", false
} }

View File

@ -128,7 +128,7 @@ func TestGetServiceInfo(t *testing.T) {
func TestStreamContext(t *testing.T) { func TestStreamContext(t *testing.T) {
expectedStream := &transport.Stream{} expectedStream := &transport.Stream{}
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream) ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
s := serverTransportStreamFromContext(ctx) s := ServerTransportStreamFromContext(ctx)
stream, ok := s.(*transport.Stream) stream, ok := s.(*transport.Stream)
if !ok || expectedStream != stream { if !ok || expectedStream != stream {
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream) t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)