mirror of https://github.com/grpc/grpc-go.git
				
				
				
			Merge pull request #863 from menghanl/setTrailer
Allow multiple setTrailer
This commit is contained in:
		
						commit
						c2983be903
					
				| 
						 | 
				
			
			@ -886,8 +886,8 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
 | 
			
		||||
// It may be called at most once from a unary RPC handler. The ctx is the RPC
 | 
			
		||||
// handler's Context or one derived from it.
 | 
			
		||||
// When called more than once, all the provided metadata will be merged.
 | 
			
		||||
// The ctx is the RPC handler's Context or one derived from it.
 | 
			
		||||
func SetTrailer(ctx context.Context, md metadata.MD) error {
 | 
			
		||||
	if md.Len() == 0 {
 | 
			
		||||
		return nil
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -414,8 +414,8 @@ type ServerStream interface {
 | 
			
		|||
	// after SendProto. It fails if called multiple times or if
 | 
			
		||||
	// called after SendProto.
 | 
			
		||||
	SendHeader(metadata.MD) error
 | 
			
		||||
	// SetTrailer sets the trailer metadata which will be sent with the
 | 
			
		||||
	// RPC status.
 | 
			
		||||
	// SetTrailer sets the trailer metadata which will be sent with the RPC status.
 | 
			
		||||
	// When called more than once, all the provided metadata will be merged.
 | 
			
		||||
	SetTrailer(metadata.MD)
 | 
			
		||||
	Stream
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -79,6 +79,10 @@ var (
 | 
			
		|||
		"tkey1": []string{"trailerValue1"},
 | 
			
		||||
		"tkey2": []string{"trailerValue2"},
 | 
			
		||||
	}
 | 
			
		||||
	testTrailerMetadata2 = metadata.MD{
 | 
			
		||||
		"tkey1": []string{"trailerValue12"},
 | 
			
		||||
		"tkey2": []string{"trailerValue22"},
 | 
			
		||||
	}
 | 
			
		||||
	// capital "Key" is illegal in HTTP/2.
 | 
			
		||||
	malformedHTTP2Metadata = metadata.MD{
 | 
			
		||||
		"Key": []string{"foo"},
 | 
			
		||||
| 
						 | 
				
			
			@ -89,8 +93,9 @@ var (
 | 
			
		|||
var raceMode bool // set by race_test.go in race mode
 | 
			
		||||
 | 
			
		||||
type testServer struct {
 | 
			
		||||
	security  string // indicate the authentication protocol used by this server.
 | 
			
		||||
	earlyFail bool   // whether to error out the execution of a service handler prematurely.
 | 
			
		||||
	security           string // indicate the authentication protocol used by this server.
 | 
			
		||||
	earlyFail          bool   // whether to error out the execution of a service handler prematurely.
 | 
			
		||||
	multipleSetTrailer bool   // whether to call setTrailer multiple times.
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
 | 
			
		||||
| 
						 | 
				
			
			@ -136,14 +141,21 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
 | 
			
		|||
		if err := grpc.SendHeader(ctx, md); err != nil {
 | 
			
		||||
			return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want %v", md, err, nil)
 | 
			
		||||
		}
 | 
			
		||||
		grpc.SetTrailer(ctx, testTrailerMetadata)
 | 
			
		||||
		if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
 | 
			
		||||
			return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
 | 
			
		||||
		}
 | 
			
		||||
		if s.multipleSetTrailer {
 | 
			
		||||
			if err := grpc.SetTrailer(ctx, testTrailerMetadata2); err != nil {
 | 
			
		||||
				return nil, grpc.Errorf(grpc.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata2, err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	pr, ok := peer.FromContext(ctx)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, fmt.Errorf("failed to get peer from ctx")
 | 
			
		||||
		return nil, grpc.Errorf(codes.DataLoss, "failed to get peer from ctx")
 | 
			
		||||
	}
 | 
			
		||||
	if pr.Addr == net.Addr(nil) {
 | 
			
		||||
		return nil, fmt.Errorf("failed to get peer address")
 | 
			
		||||
		return nil, grpc.Errorf(codes.DataLoss, "failed to get peer address")
 | 
			
		||||
	}
 | 
			
		||||
	if s.security != "" {
 | 
			
		||||
		// Check Auth info
 | 
			
		||||
| 
						 | 
				
			
			@ -153,13 +165,13 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
 | 
			
		|||
			authType = info.AuthType()
 | 
			
		||||
			serverName = info.State.ServerName
 | 
			
		||||
		default:
 | 
			
		||||
			return nil, fmt.Errorf("Unknown AuthInfo type")
 | 
			
		||||
			return nil, grpc.Errorf(codes.Unauthenticated, "Unknown AuthInfo type")
 | 
			
		||||
		}
 | 
			
		||||
		if authType != s.security {
 | 
			
		||||
			return nil, fmt.Errorf("Wrong auth type: got %q, want %q", authType, s.security)
 | 
			
		||||
			return nil, grpc.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security)
 | 
			
		||||
		}
 | 
			
		||||
		if serverName != "x.test.youtube.com" {
 | 
			
		||||
			return nil, fmt.Errorf("Unknown server name %q", serverName)
 | 
			
		||||
			return nil, grpc.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	// Simulate some service delay.
 | 
			
		||||
| 
						 | 
				
			
			@ -229,9 +241,12 @@ func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServ
 | 
			
		|||
	md, ok := metadata.FromContext(stream.Context())
 | 
			
		||||
	if ok {
 | 
			
		||||
		if err := stream.SendHeader(md); err != nil {
 | 
			
		||||
			return fmt.Errorf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
 | 
			
		||||
			return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
 | 
			
		||||
		}
 | 
			
		||||
		stream.SetTrailer(testTrailerMetadata)
 | 
			
		||||
		if s.multipleSetTrailer {
 | 
			
		||||
			stream.SetTrailer(testTrailerMetadata2)
 | 
			
		||||
		}
 | 
			
		||||
		stream.SetTrailer(md)
 | 
			
		||||
	}
 | 
			
		||||
	for {
 | 
			
		||||
		in, err := stream.Recv()
 | 
			
		||||
| 
						 | 
				
			
			@ -1193,6 +1208,76 @@ func testMetadataUnaryRPC(t *testing.T, e env) {
 | 
			
		|||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMultipleSetTrailerUnaryRPC(t *testing.T) {
 | 
			
		||||
	defer leakCheck(t)()
 | 
			
		||||
	for _, e := range listTestEnv() {
 | 
			
		||||
		testMultipleSetTrailerUnaryRPC(t, e)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) {
 | 
			
		||||
	te := newTest(t, e)
 | 
			
		||||
	te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
 | 
			
		||||
	defer te.tearDown()
 | 
			
		||||
	tc := testpb.NewTestServiceClient(te.clientConn())
 | 
			
		||||
 | 
			
		||||
	const (
 | 
			
		||||
		argSize  = 1
 | 
			
		||||
		respSize = 1
 | 
			
		||||
	)
 | 
			
		||||
	payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req := &testpb.SimpleRequest{
 | 
			
		||||
		ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
 | 
			
		||||
		ResponseSize: proto.Int32(respSize),
 | 
			
		||||
		Payload:      payload,
 | 
			
		||||
	}
 | 
			
		||||
	var trailer metadata.MD
 | 
			
		||||
	ctx := metadata.NewContext(context.Background(), testMetadata)
 | 
			
		||||
	if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil {
 | 
			
		||||
		t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
 | 
			
		||||
	}
 | 
			
		||||
	expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
 | 
			
		||||
	if !reflect.DeepEqual(trailer, expectedTrailer) {
 | 
			
		||||
		t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestMultipleSetTrailerStreamingRPC(t *testing.T) {
 | 
			
		||||
	defer leakCheck(t)()
 | 
			
		||||
	for _, e := range listTestEnv() {
 | 
			
		||||
		testMultipleSetTrailerStreamingRPC(t, e)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) {
 | 
			
		||||
	te := newTest(t, e)
 | 
			
		||||
	te.startServer(&testServer{security: e.security, multipleSetTrailer: true})
 | 
			
		||||
	defer te.tearDown()
 | 
			
		||||
	tc := testpb.NewTestServiceClient(te.clientConn())
 | 
			
		||||
 | 
			
		||||
	ctx := metadata.NewContext(context.Background(), testMetadata)
 | 
			
		||||
	stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := stream.CloseSend(); err != nil {
 | 
			
		||||
		t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil)
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := stream.Recv(); err != io.EOF {
 | 
			
		||||
		t.Fatalf("%v failed to complele the FullDuplexCall: %v", stream, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	trailer := stream.Trailer()
 | 
			
		||||
	expectedTrailer := metadata.Join(testTrailerMetadata, testTrailerMetadata2)
 | 
			
		||||
	if !reflect.DeepEqual(trailer, expectedTrailer) {
 | 
			
		||||
		t.Fatalf("Received trailer metadata %v, want %v", trailer, expectedTrailer)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TestMalformedHTTP2Metedata verfies the returned error when the client
 | 
			
		||||
// sends an illegal metadata.
 | 
			
		||||
func TestMalformedHTTP2Metadata(t *testing.T) {
 | 
			
		||||
| 
						 | 
				
			
			@ -1601,8 +1686,8 @@ func testMetadataStreamingRPC(t *testing.T, e env) {
 | 
			
		|||
		}
 | 
			
		||||
	}
 | 
			
		||||
	trailerMD := stream.Trailer()
 | 
			
		||||
	if !reflect.DeepEqual(testMetadata, trailerMD) {
 | 
			
		||||
		t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata)
 | 
			
		||||
	if !reflect.DeepEqual(testTrailerMetadata, trailerMD) {
 | 
			
		||||
		t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testTrailerMetadata)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -39,7 +39,6 @@ package transport // import "google.golang.org/grpc/transport"
 | 
			
		|||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
| 
						 | 
				
			
			@ -287,19 +286,12 @@ func (s *Stream) StatusDesc() string {
 | 
			
		|||
	return s.statusDesc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrIllegalTrailerSet indicates that the trailer has already been set or it
 | 
			
		||||
// is too late to do so.
 | 
			
		||||
var ErrIllegalTrailerSet = errors.New("transport: trailer has been set")
 | 
			
		||||
 | 
			
		||||
// SetTrailer sets the trailer metadata which will be sent with the RPC status
 | 
			
		||||
// by the server. This can only be called at most once. Server side only.
 | 
			
		||||
// by the server. This can be called multiple times. Server side only.
 | 
			
		||||
func (s *Stream) SetTrailer(md metadata.MD) error {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
	if s.trailer != nil {
 | 
			
		||||
		return ErrIllegalTrailerSet
 | 
			
		||||
	}
 | 
			
		||||
	s.trailer = md.Copy()
 | 
			
		||||
	s.trailer = metadata.Join(s.trailer, md)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue