mirror of https://github.com/grpc/grpc-go.git
server: expose API to set send compressor (#5744)
Fixes https://github.com/grpc/grpc-go/issues/5792
This commit is contained in:
parent
a7058f7b72
commit
0954097276
|
@ -280,31 +280,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
|
|||
t.Errorf("stream method = %q; want %q", s.method, want)
|
||||
}
|
||||
|
||||
err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
|
||||
if err != nil {
|
||||
if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
|
||||
if err != nil {
|
||||
|
||||
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := s.SetSendCompress("gzip"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
md := metadata.Pairs("custom-header", "Another custom header value")
|
||||
err = s.SendHeader(md)
|
||||
delete(md, "custom-header")
|
||||
if err != nil {
|
||||
if err := s.SendHeader(md); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
delete(md, "custom-header")
|
||||
|
||||
err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
|
||||
if err == nil {
|
||||
if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
|
||||
t.Error("expected SetHeader call after SendHeader to fail")
|
||||
}
|
||||
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
|
||||
if err == nil {
|
||||
|
||||
if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
|
||||
t.Error("expected second SendHeader call to fail")
|
||||
}
|
||||
|
||||
if err := s.SetSendCompress("snappy"); err == nil {
|
||||
t.Error("expected second SetSendCompress call to fail")
|
||||
}
|
||||
|
||||
st.bodyw.Close() // no body
|
||||
st.ht.WriteStatus(s, status.New(codes.OK, ""))
|
||||
}
|
||||
|
@ -317,6 +322,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
|
|||
"Content-Type": {"application/grpc"},
|
||||
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
|
||||
"Custom-Header": {"Custom header value", "Another custom header value"},
|
||||
"Grpc-Encoding": {"gzip"},
|
||||
}
|
||||
wantTrailer := http.Header{
|
||||
"Grpc-Status": {"0"},
|
||||
|
|
|
@ -404,6 +404,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
|
|||
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
|
||||
s.contentSubtype = contentSubtype
|
||||
isGRPC = true
|
||||
|
||||
case "grpc-accept-encoding":
|
||||
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
|
||||
if hf.Value == "" {
|
||||
continue
|
||||
}
|
||||
compressors := hf.Value
|
||||
if s.clientAdvertisedCompressors != "" {
|
||||
compressors = s.clientAdvertisedCompressors + "," + compressors
|
||||
}
|
||||
s.clientAdvertisedCompressors = compressors
|
||||
case "grpc-encoding":
|
||||
s.recvCompress = hf.Value
|
||||
case ":method":
|
||||
|
|
|
@ -257,6 +257,9 @@ type Stream struct {
|
|||
fc *inFlow
|
||||
wq *writeQuota
|
||||
|
||||
// Holds compressor names passed in grpc-accept-encoding metadata from the
|
||||
// client. This is empty for the client side stream.
|
||||
clientAdvertisedCompressors string
|
||||
// Callback to state application's intentions to read data. This
|
||||
// is used to adjust flow control, if needed.
|
||||
requestRead func(int)
|
||||
|
@ -345,8 +348,24 @@ func (s *Stream) RecvCompress() string {
|
|||
}
|
||||
|
||||
// SetSendCompress sets the compression algorithm to the stream.
|
||||
func (s *Stream) SetSendCompress(str string) {
|
||||
s.sendCompress = str
|
||||
func (s *Stream) SetSendCompress(name string) error {
|
||||
if s.isHeaderSent() || s.getState() == streamDone {
|
||||
return errors.New("transport: set send compressor called after headers sent or stream done")
|
||||
}
|
||||
|
||||
s.sendCompress = name
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendCompress returns the send compressor name.
|
||||
func (s *Stream) SendCompress() string {
|
||||
return s.sendCompress
|
||||
}
|
||||
|
||||
// ClientAdvertisedCompressors returns the compressor names advertised by the
|
||||
// client via grpc-accept-encoding header.
|
||||
func (s *Stream) ClientAdvertisedCompressors() string {
|
||||
return s.clientAdvertisedCompressors
|
||||
}
|
||||
|
||||
// Done returns a channel which is closed when it receives the final status
|
||||
|
|
100
server.go
100
server.go
|
@ -45,6 +45,7 @@ import (
|
|||
"google.golang.org/grpc/internal/channelz"
|
||||
"google.golang.org/grpc/internal/grpcrand"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
"google.golang.org/grpc/internal/grpcutil"
|
||||
"google.golang.org/grpc/internal/transport"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
@ -1263,6 +1264,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
var comp, decomp encoding.Compressor
|
||||
var cp Compressor
|
||||
var dc Decompressor
|
||||
var sendCompressorName string
|
||||
|
||||
// If dc is set and matches the stream's compression, use it. Otherwise, try
|
||||
// to find a matching registered compressor for decomp.
|
||||
|
@ -1283,12 +1285,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
|
||||
if s.opts.cp != nil {
|
||||
cp = s.opts.cp
|
||||
stream.SetSendCompress(cp.Type())
|
||||
sendCompressorName = cp.Type()
|
||||
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
|
||||
// Legacy compressor not specified; attempt to respond with same encoding.
|
||||
comp = encoding.GetCompressor(rc)
|
||||
if comp != nil {
|
||||
stream.SetSendCompress(rc)
|
||||
sendCompressorName = comp.Name()
|
||||
}
|
||||
}
|
||||
|
||||
if sendCompressorName != "" {
|
||||
if err := stream.SetSendCompress(sendCompressorName); err != nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1375,6 +1383,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
|
|||
}
|
||||
opts := &transport.Options{Last: true}
|
||||
|
||||
// Server handler could have set new compressor by calling SetSendCompressor.
|
||||
// In case it is set, we need to use it for compressing outbound message.
|
||||
if stream.SendCompress() != sendCompressorName {
|
||||
comp = encoding.GetCompressor(stream.SendCompress())
|
||||
}
|
||||
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
|
||||
if err == io.EOF {
|
||||
// The entire stream is done (for unary RPC only).
|
||||
|
@ -1597,12 +1610,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
|
|||
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
|
||||
if s.opts.cp != nil {
|
||||
ss.cp = s.opts.cp
|
||||
stream.SetSendCompress(s.opts.cp.Type())
|
||||
ss.sendCompressorName = s.opts.cp.Type()
|
||||
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
|
||||
// Legacy compressor not specified; attempt to respond with same encoding.
|
||||
ss.comp = encoding.GetCompressor(rc)
|
||||
if ss.comp != nil {
|
||||
stream.SetSendCompress(rc)
|
||||
ss.sendCompressorName = rc
|
||||
}
|
||||
}
|
||||
|
||||
if ss.sendCompressorName != "" {
|
||||
if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
|
||||
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1935,6 +1954,60 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetSendCompressor sets a compressor for outbound messages from the server.
|
||||
// It must not be called after any event that causes headers to be sent
|
||||
// (see ServerStream.SetHeader for the complete list). Provided compressor is
|
||||
// used when below conditions are met:
|
||||
//
|
||||
// - compressor is registered via encoding.RegisterCompressor
|
||||
// - compressor name must exist in the client advertised compressor names
|
||||
// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to
|
||||
// get client supported compressor names.
|
||||
//
|
||||
// The context provided must be the context passed to the server's handler.
|
||||
// It must be noted that compressor name encoding.Identity disables the
|
||||
// outbound compression.
|
||||
// By default, server messages will be sent using the same compressor with
|
||||
// which request messages were sent.
|
||||
//
|
||||
// It is not safe to call SetSendCompressor concurrently with SendHeader and
|
||||
// SendMsg.
|
||||
//
|
||||
// # Experimental
|
||||
//
|
||||
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
|
||||
// later release.
|
||||
func SetSendCompressor(ctx context.Context, name string) error {
|
||||
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
|
||||
if !ok || stream == nil {
|
||||
return fmt.Errorf("failed to fetch the stream from the given context")
|
||||
}
|
||||
|
||||
if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
|
||||
return fmt.Errorf("unable to set send compressor: %w", err)
|
||||
}
|
||||
|
||||
return stream.SetSendCompress(name)
|
||||
}
|
||||
|
||||
// ClientSupportedCompressors returns compressor names advertised by the client
|
||||
// via grpc-accept-encoding header.
|
||||
//
|
||||
// The context provided must be the context passed to the server's handler.
|
||||
//
|
||||
// # Experimental
|
||||
//
|
||||
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
|
||||
// later release.
|
||||
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
|
||||
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
|
||||
if !ok || stream == nil {
|
||||
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
|
||||
}
|
||||
|
||||
return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
|
||||
}
|
||||
|
||||
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
|
||||
// When called more than once, all the provided metadata will be merged.
|
||||
//
|
||||
|
@ -1969,3 +2042,22 @@ type channelzServer struct {
|
|||
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
|
||||
return c.s.channelzMetric()
|
||||
}
|
||||
|
||||
// validateSendCompressor returns an error when given compressor name cannot be
|
||||
// handled by the server or the client based on the advertised compressors.
|
||||
func validateSendCompressor(name, clientCompressors string) error {
|
||||
if name == encoding.Identity {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !grpcutil.IsCompressorNameRegistered(name) {
|
||||
return fmt.Errorf("compressor not registered %q", name)
|
||||
}
|
||||
|
||||
for _, c := range strings.Split(clientCompressors, ",") {
|
||||
if c == name {
|
||||
return nil // found match
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("client does not support compressor %q", name)
|
||||
}
|
||||
|
|
|
@ -1511,6 +1511,8 @@ type serverStream struct {
|
|||
comp encoding.Compressor
|
||||
decomp encoding.Compressor
|
||||
|
||||
sendCompressorName string
|
||||
|
||||
maxReceiveMessageSize int
|
||||
maxSendMessageSize int
|
||||
trInfo *traceInfo
|
||||
|
@ -1603,6 +1605,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
|
|||
}
|
||||
}()
|
||||
|
||||
// Server handler could have set new compressor by calling SetSendCompressor.
|
||||
// In case it is set, we need to use it for compressing outbound message.
|
||||
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
|
||||
ss.comp = encoding.GetCompressor(sendCompressorsName)
|
||||
ss.sendCompressorName = sendCompressorsName
|
||||
}
|
||||
|
||||
// load hdr, payload, data
|
||||
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
|
||||
if err != nil {
|
||||
|
|
|
@ -59,6 +59,7 @@ import (
|
|||
"google.golang.org/grpc/internal"
|
||||
"google.golang.org/grpc/internal/binarylog"
|
||||
"google.golang.org/grpc/internal/channelz"
|
||||
"google.golang.org/grpc/internal/envconfig"
|
||||
"google.golang.org/grpc/internal/grpcsync"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/internal/stubserver"
|
||||
|
@ -5080,6 +5081,340 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
|
||||
// Compressor method invokes.
|
||||
type wrapCompressor struct {
|
||||
encoding.Compressor
|
||||
compressInvokes int32
|
||||
}
|
||||
|
||||
func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
|
||||
atomic.AddInt32(&wc.compressInvokes, 1)
|
||||
return wc.Compressor.Compress(w)
|
||||
}
|
||||
|
||||
func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
|
||||
oldC := encoding.GetCompressor("gzip")
|
||||
c := &wrapCompressor{Compressor: oldC}
|
||||
encoding.RegisterCompressor(c)
|
||||
t.Cleanup(func() {
|
||||
encoding.RegisterCompressor(oldC)
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
func (s) TestSetSendCompressorSuccess(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
desc string
|
||||
dialOpts []grpc.DialOption
|
||||
resCompressor string
|
||||
wantCompressInvokes int32
|
||||
}{
|
||||
{
|
||||
name: "identity_request_and_gzip_response",
|
||||
desc: "request is uncompressed and response is gzip compressed",
|
||||
resCompressor: "gzip",
|
||||
wantCompressInvokes: 1,
|
||||
},
|
||||
{
|
||||
name: "gzip_request_and_identity_response",
|
||||
desc: "request is gzip compressed and response is uncompressed with identity",
|
||||
resCompressor: "identity",
|
||||
dialOpts: []grpc.DialOption{
|
||||
// Use WithCompressor instead of UseCompressor to avoid counting
|
||||
// the client's compressor usage.
|
||||
grpc.WithCompressor(grpc.NewGZIPCompressor()),
|
||||
},
|
||||
wantCompressInvokes: 0,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
|
||||
wc := setupGzipWrapCompressor(t)
|
||||
ss := &stubserver.StubServer{
|
||||
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testpb.Empty{}, nil
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil, dialOpts...); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
|
||||
if compressInvokes != wantCompressInvokes {
|
||||
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
|
||||
wc := setupGzipWrapCompressor(t)
|
||||
ss := &stubserver.StubServer{
|
||||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.Send(&testpb.StreamingOutputCallResponse{})
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil, dialOpts...); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
s, err := ss.Client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
|
||||
t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
if _, err := s.Recv(); err != nil {
|
||||
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
|
||||
if compressInvokes != wantCompressInvokes {
|
||||
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) {
|
||||
resCompressor := "snappy2"
|
||||
wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"")
|
||||
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
|
||||
})
|
||||
}
|
||||
|
||||
func (s) TestUnadvertisedSetSendCompressorFailure(t *testing.T) {
|
||||
// Disable client compressor advertisement.
|
||||
defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors)
|
||||
envconfig.AdvertiseCompressors = false
|
||||
|
||||
resCompressor := "gzip"
|
||||
wantErr := status.Error(codes.Unknown, "unable to set send compressor: client does not support compressor \"gzip\"")
|
||||
|
||||
t.Run("unary", func(t *testing.T) {
|
||||
testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
|
||||
})
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
|
||||
})
|
||||
}
|
||||
|
||||
func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
|
||||
ss := &stubserver.StubServer{
|
||||
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testpb.Empty{}, nil
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
|
||||
t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
|
||||
ss := &stubserver.StubServer{
|
||||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.Send(&testpb.StreamingOutputCallResponse{})
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v, want: nil", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
s, err := ss.Client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
|
||||
t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
if _, err := s.Recv(); !equalError(err, wantErr) {
|
||||
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) {
|
||||
ss := &stubserver.StubServer{
|
||||
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
// Send headers early and then set send compressor.
|
||||
grpc.SendHeader(ctx, metadata.MD{})
|
||||
err := grpc.SetSendCompressor(ctx, "gzip")
|
||||
if err == nil {
|
||||
t.Error("Wanted set send compressor error")
|
||||
return &testpb.Empty{}, nil
|
||||
}
|
||||
return nil, err
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
|
||||
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
|
||||
t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) {
|
||||
ss := &stubserver.StubServer{
|
||||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
|
||||
// Send headers early and then set send compressor.
|
||||
grpc.SendHeader(stream.Context(), metadata.MD{})
|
||||
err := grpc.SetSendCompressor(stream.Context(), "gzip")
|
||||
if err == nil {
|
||||
t.Error("Wanted set send compressor error")
|
||||
}
|
||||
return err
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
|
||||
s, err := ss.Client.FullDuplexCall(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
|
||||
}
|
||||
|
||||
if _, err := s.Recv(); !equalError(err, wantErr) {
|
||||
t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestClientSupportedCompressors(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
desc string
|
||||
ctx context.Context
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
desc: "No additional grpc-accept-encoding header",
|
||||
ctx: context.Background(),
|
||||
want: []string{"gzip"},
|
||||
},
|
||||
{
|
||||
desc: "With additional grpc-accept-encoding header",
|
||||
ctx: metadata.AppendToOutgoingContext(context.Background(),
|
||||
"grpc-accept-encoding", "test-compressor-1",
|
||||
"grpc-accept-encoding", "test-compressor-2",
|
||||
),
|
||||
want: []string{"gzip", "test-compressor-1", "test-compressor-2"},
|
||||
},
|
||||
{
|
||||
desc: "With additional empty grpc-accept-encoding header",
|
||||
ctx: metadata.AppendToOutgoingContext(context.Background(),
|
||||
"grpc-accept-encoding", "",
|
||||
),
|
||||
want: []string{"gzip"},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
ss := &stubserver.StubServer{
|
||||
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||
got, err := grpc.ClientSupportedCompressors(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want)
|
||||
}
|
||||
|
||||
return &testpb.Empty{}, nil
|
||||
},
|
||||
}
|
||||
if err := ss.Start(nil); err != nil {
|
||||
t.Fatalf("Error starting endpoint server: %v, want: nil", err)
|
||||
}
|
||||
defer ss.Stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(tt.ctx, defaultTestTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
|
||||
const mdkey = "somedata"
|
||||
|
||||
|
|
Loading…
Reference in New Issue