mirror of https://github.com/grpc/grpc-go.git
server: export ServerTransportStreamFromContext for unary interceptors to control headers/trailers (#2019)
This commit is contained in:
parent
07709e8a3d
commit
fc37cf1364
18
server.go
18
server.go
|
@ -1298,10 +1298,12 @@ type ServerTransportStream interface {
|
|||
SetTrailer(md metadata.MD) error
|
||||
}
|
||||
|
||||
// serverStreamFromContext returns the server stream saved in ctx. Returns
|
||||
// nil if the given context has no stream associated with it (which implies
|
||||
// it is not an RPC invocation context).
|
||||
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||
// ServerTransportStreamFromContext returns the ServerTransportStream saved in
|
||||
// ctx. Returns nil if the given context has no stream associated with it
|
||||
// (which implies it is not an RPC invocation context).
|
||||
//
|
||||
// This API is EXPERIMENTAL.
|
||||
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
||||
return s
|
||||
}
|
||||
|
@ -1438,7 +1440,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
stream := ServerTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
|
@ -1448,7 +1450,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
|||
// SendHeader sends header metadata. It may be called at most once.
|
||||
// The provided md and headers set by SetHeader() will be sent.
|
||||
func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
stream := ServerTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
|
@ -1464,7 +1466,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
|||
if md.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
stream := serverTransportStreamFromContext(ctx)
|
||||
stream := ServerTransportStreamFromContext(ctx)
|
||||
if stream == nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||
}
|
||||
|
@ -1474,7 +1476,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
|||
// Method returns the method string for the server context. The returned
|
||||
// string is in the format of "/service/method".
|
||||
func Method(ctx context.Context) (string, bool) {
|
||||
s := serverTransportStreamFromContext(ctx)
|
||||
s := ServerTransportStreamFromContext(ctx)
|
||||
if s == nil {
|
||||
return "", false
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ func TestGetServiceInfo(t *testing.T) {
|
|||
func TestStreamContext(t *testing.T) {
|
||||
expectedStream := &transport.Stream{}
|
||||
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
||||
s := serverTransportStreamFromContext(ctx)
|
||||
s := ServerTransportStreamFromContext(ctx)
|
||||
stream, ok := s.(*transport.Stream)
|
||||
if !ok || expectedStream != stream {
|
||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
|
||||
|
|
Loading…
Reference in New Issue