mirror of https://github.com/grpc/grpc-go.git
Merge pull request #863 from menghanl/setTrailer
Allow multiple setTrailer
This commit is contained in:
commit
c2983be903
|
|
@ -886,8 +886,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
|
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
|
||||||
// It may be called at most once from a unary RPC handler. The ctx is the RPC
|
// When called more than once, all the provided metadata will be merged.
|
||||||
// handler's Context or one derived from it.
|
// The ctx is the RPC handler's Context or one derived from it.
|
||||||
func SetTrailer(ctx context.Context, md metadata.MD) error {
|
func SetTrailer(ctx context.Context, md metadata.MD) error {
|
||||||
if md.Len() == 0 {
|
if md.Len() == 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -414,8 +414,8 @@ type ServerStream interface {
|
||||||
// after SendProto. It fails if called multiple times or if
|
// after SendProto. It fails if called multiple times or if
|
||||||
// called after SendProto.
|
// called after SendProto.
|
||||||
SendHeader(metadata.MD) error
|
SendHeader(metadata.MD) error
|
||||||
// SetTrailer sets the trailer metadata which will be sent with the
|
// SetTrailer sets the trailer metadata which will be sent with the RPC status.
|
||||||
// RPC status.
|
// When called more than once, all the provided metadata will be merged.
|
||||||
SetTrailer(metadata.MD)
|
SetTrailer(metadata.MD)
|
||||||
Stream
|
Stream
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,10 @@ var (
|
||||||
"tkey1": []string{"trailerValue1"},
|
"tkey1": []string{"trailerValue1"},
|
||||||
"tkey2": []string{"trailerValue2"},
|
"tkey2": []string{"trailerValue2"},
|
||||||
}
|
}
|
||||||
|
testTrailerMetadata2 = metadata.MD{
|
||||||
|
"tkey1": []string{"trailerValue12"},
|
||||||
|
"tkey2": []string{"trailerValue22"},
|
||||||
|
}
|
||||||
// capital "Key" is illegal in HTTP/2.
|
// capital "Key" is illegal in HTTP/2.
|
||||||
malformedHTTP2Metadata = metadata.MD{
|
malformedHTTP2Metadata = metadata.MD{
|
||||||
"Key": []string{"foo"},
|
"Key": []string{"foo"},
|
||||||
|
|
@ -89,8 +93,9 @@ var (
|
||||||
var raceMode bool // set by race_test.go in race mode
|
var raceMode bool // set by race_test.go in race mode
|
||||||
|
|
||||||
type testServer struct {
|
type testServer struct {
|
||||||
security string // indicate the authentication protocol used by this server.
|
security string // indicate the authentication protocol used by this server.
|
||||||
earlyFail bool // whether to error out the execution of a service handler prematurely.
|
earlyFail bool // whether to error out the execution of a service handler prematurely.
|
||||||
|
multipleSetTrailer bool // whether to call setTrailer multiple times.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
|
@ -136,14 +141,21 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||||
if err := grpc.SendHeader(ctx, md); err != nil {
|
if err := grpc.SendHeader(ctx, md); err != nil {
|
||||||
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil)
|
return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil)
|
||||||
}
|
}
|
||||||
grpc.SetTrailer(ctx, testTrailerMetadata)
|
if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
|
||||||
|
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
|
||||||
|
}
|
||||||
|
if s.multipleSetTrailer {
|
||||||
|
if err := grpc.SetTrailer(ctx, testTrailerMetadata2); err != nil {
|
||||||
|
return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata2, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
pr, ok := peer.FromContext(ctx)
|
pr, ok := peer.FromContext(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("failed to get peer from ctx")
|
return nil, grpc.Errorf(codes.DataLoss, "failed to get peer from ctx")
|
||||||
}
|
}
|
||||||
if pr.Addr == net.Addr(nil) {
|
if pr.Addr == net.Addr(nil) {
|
||||||
return nil, fmt.Errorf("failed to get peer address")
|
return nil, grpc.Errorf(codes.DataLoss, "failed to get peer address")
|
||||||
}
|
}
|
||||||
if s.security != "" {
|
if s.security != "" {
|
||||||
// Check Auth info
|
// Check Auth info
|
||||||
|
|
@ -153,13 +165,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
|
||||||
authType = info.AuthType()
|
authType = info.AuthType()
|
||||||
serverName = info.State.ServerName
|
serverName = info.State.ServerName
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknown AuthInfo type")
|
return nil, grpc.Errorf(codes.Unauthenticated, "Unknown AuthInfo type")
|
||||||
}
|
}
|
||||||
if authType != s.security {
|
if authType != s.security {
|
||||||
return nil, fmt.Errorf("Wrong auth type: got %q, want %q", authType, s.security)
|
return nil, grpc.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security)
|
||||||
}
|
}
|
||||||
if serverName != "x.test.youtube.com" {
|
if serverName != "x.test.youtube.com" {
|
||||||
return nil, fmt.Errorf("Unknown server name %q", serverName)
|
return nil, grpc.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Simulate some service delay.
|
// Simulate some service delay.
|
||||||
|
|
@ -229,9 +241,12 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
|
||||||
md, ok := metadata.FromContext(stream.Context())
|
md, ok := metadata.FromContext(stream.Context())
|
||||||
if ok {
|
if ok {
|
||||||
if err := stream.SendHeader(md); err != nil {
|
if err := stream.SendHeader(md); err != nil {
|
||||||
return fmt.Errorf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
|
return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
|
||||||
|
}
|
||||||
|
stream.SetTrailer(testTrailerMetadata)
|
||||||
|
if s.multipleSetTrailer {
|
||||||
|
stream.SetTrailer(testTrailerMetadata2)
|
||||||
}
|
}
|
||||||
stream.SetTrailer(md)
|
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
in, err := stream.Recv()
|
in, err := stream.Recv()
|
||||||
|
|
@ -1193,6 +1208,76 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMultipleSetTrailerUnaryRPC(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testMultipleSetTrailerUnaryRPC(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
|
||||||
|
defer te.tearDown()
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
|
||||||
|
const (
|
||||||
|
argSize = 1
|
||||||
|
respSize = 1
|
||||||
|
)
|
||||||
|
payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &testpb.SimpleRequest{
|
||||||
|
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
|
||||||
|
ResponseSize: proto.Int32(respSize),
|
||||||
|
Payload: payload,
|
||||||
|
}
|
||||||
|
var trailer metadata.MD
|
||||||
|
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||||
|
if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
|
||||||
|
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
|
||||||
|
}
|
||||||
|
expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
|
||||||
|
if !reflect.DeepEqual(trailer, expectedTrailer) {
|
||||||
|
t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleSetTrailerStreamingRPC(t *testing.T) {
|
||||||
|
defer leakCheck(t)()
|
||||||
|
for _, e := range listTestEnv() {
|
||||||
|
testMultipleSetTrailerStreamingRPC(t, e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
|
||||||
|
te := newTest(t, e)
|
||||||
|
te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
|
||||||
|
defer te.tearDown()
|
||||||
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
|
||||||
|
ctx := metadata.NewContext(context.Background(), testMetadata)
|
||||||
|
stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
|
||||||
|
}
|
||||||
|
if err := stream.CloseSend(); err != nil {
|
||||||
|
t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
|
||||||
|
}
|
||||||
|
if _, err := stream.Recv(); err != io.EOF {
|
||||||
|
t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
trailer := stream.Trailer()
|
||||||
|
expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
|
||||||
|
if !reflect.DeepEqual(trailer, expectedTrailer) {
|
||||||
|
t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestMalformedHTTP2Metedata verfies the returned error when the client
|
// TestMalformedHTTP2Metedata verfies the returned error when the client
|
||||||
// sends an illegal metadata.
|
// sends an illegal metadata.
|
||||||
func TestMalformedHTTP2Metadata(t *testing.T) {
|
func TestMalformedHTTP2Metadata(t *testing.T) {
|
||||||
|
|
@ -1601,8 +1686,8 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
trailerMD := stream.Trailer()
|
trailerMD := stream.Trailer()
|
||||||
if !reflect.DeepEqual(testMetadata, trailerMD) {
|
if !reflect.DeepEqual(testTrailerMetadata, trailerMD) {
|
||||||
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
|
t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testTrailerMetadata)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,6 @@ package transport // import "google.golang.org/grpc/transport"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -287,19 +286,12 @@ func (s *Stream) StatusDesc() string {
|
||||||
return s.statusDesc
|
return s.statusDesc
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrIllegalTrailerSet indicates that the trailer has already been set or it
|
|
||||||
// is too late to do so.
|
|
||||||
var ErrIllegalTrailerSet = errors.New("transport: trailer has been set")
|
|
||||||
|
|
||||||
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
// SetTrailer sets the trailer metadata which will be sent with the RPC status
|
||||||
// by the server. This can only be called at most once. Server side only.
|
// by the server. This can be called multiple times. Server side only.
|
||||||
func (s *Stream) SetTrailer(md metadata.MD) error {
|
func (s *Stream) SetTrailer(md metadata.MD) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
if s.trailer != nil {
|
s.trailer = metadata.Join(s.trailer, md)
|
||||||
return ErrIllegalTrailerSet
|
|
||||||
}
|
|
||||||
s.trailer = md.Copy()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue