mirror of https://github.com/grpc/grpc-go.git
server: export ServerTransportStreamFromContext for unary interceptors to control headers/trailers (#2019)
This commit is contained in:
parent
07709e8a3d
commit
fc37cf1364
18
server.go
18
server.go
|
@ -1298,10 +1298,12 @@ type ServerTransportStream interface {
|
||||||
SetTrailer(md metadata.MD) error
|
SetTrailer(md metadata.MD) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// serverStreamFromContext returns the server stream saved in ctx. Returns
|
// ServerTransportStreamFromContext returns the ServerTransportStream saved in
|
||||||
// nil if the given context has no stream associated with it (which implies
|
// ctx. Returns nil if the given context has no stream associated with it
|
||||||
// it is not an RPC invocation context).
|
// (which implies it is not an RPC invocation context).
|
||||||
func serverTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
//
|
||||||
|
// This API is EXPERIMENTAL.
|
||||||
|
func ServerTransportStreamFromContext(ctx context.Context) ServerTransportStream {
|
||||||
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
s, _ := ctx.Value(streamKey{}).(ServerTransportStream)
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
@ -1438,7 +1440,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
||||||
if md.Len() == 0 {
|
if md.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stream := serverTransportStreamFromContext(ctx)
|
stream := ServerTransportStreamFromContext(ctx)
|
||||||
if stream == nil {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
|
@ -1448,7 +1450,7 @@ func SetHeader(ctx context.Context, md metadata.MD) error {
|
||||||
// SendHeader sends header metadata. It may be called at most once.
|
// SendHeader sends header metadata. It may be called at most once.
|
||||||
// The provided md and headers set by SetHeader() will be sent.
|
// The provided md and headers set by SetHeader() will be sent.
|
||||||
func SendHeader(ctx context.Context, md metadata.MD) error {
|
func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||||
stream := serverTransportStreamFromContext(ctx)
|
stream := ServerTransportStreamFromContext(ctx)
|
||||||
if stream == nil {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
|
@ -1464,7 +1466,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
||||||
if md.Len() == 0 {
|
if md.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
stream := serverTransportStreamFromContext(ctx)
|
stream := ServerTransportStreamFromContext(ctx)
|
||||||
if stream == nil {
|
if stream == nil {
|
||||||
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
return status.Errorf(codes.Internal, "grpc: failed to fetch the stream from the context %v", ctx)
|
||||||
}
|
}
|
||||||
|
@ -1474,7 +1476,7 @@ func SetTrailer(ctx context.Context, md metadata.MD) error {
|
||||||
// Method returns the method string for the server context. The returned
|
// Method returns the method string for the server context. The returned
|
||||||
// string is in the format of "/service/method".
|
// string is in the format of "/service/method".
|
||||||
func Method(ctx context.Context) (string, bool) {
|
func Method(ctx context.Context) (string, bool) {
|
||||||
s := serverTransportStreamFromContext(ctx)
|
s := ServerTransportStreamFromContext(ctx)
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
|
@ -128,7 +128,7 @@ func TestGetServiceInfo(t *testing.T) {
|
||||||
func TestStreamContext(t *testing.T) {
|
func TestStreamContext(t *testing.T) {
|
||||||
expectedStream := &transport.Stream{}
|
expectedStream := &transport.Stream{}
|
||||||
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
ctx := NewContextWithServerTransportStream(context.Background(), expectedStream)
|
||||||
s := serverTransportStreamFromContext(ctx)
|
s := ServerTransportStreamFromContext(ctx)
|
||||||
stream, ok := s.(*transport.Stream)
|
stream, ok := s.(*transport.Stream)
|
||||||
if !ok || expectedStream != stream {
|
if !ok || expectedStream != stream {
|
||||||
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
|
t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream)
|
||||||
|
|
Loading…
Reference in New Issue