server: expose API to set send compressor (#5744)

Fixes https://github.com/grpc/grpc-go/issues/5792
This commit is contained in:
Ronak Jain 2023-02-01 02:57:34 +05:30 committed by GitHub
parent a7058f7b72
commit 0954097276
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 489 additions and 17 deletions

View File

@ -280,31 +280,36 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
t.Errorf("stream method = %q; want %q", s.method, want)
}
err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
if err != nil {
if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
t.Error(err)
}
err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
if err != nil {
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}
if err := s.SetSendCompress("gzip"); err != nil {
t.Error(err)
}
md := metadata.Pairs("custom-header", "Another custom header value")
err = s.SendHeader(md)
delete(md, "custom-header")
if err != nil {
if err := s.SendHeader(md); err != nil {
t.Error(err)
}
delete(md, "custom-header")
err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
if err == nil {
if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
t.Error("expected SetHeader call after SendHeader to fail")
}
err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
if err == nil {
if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
t.Error("expected second SendHeader call to fail")
}
if err := s.SetSendCompress("snappy"); err == nil {
t.Error("expected second SetSendCompress call to fail")
}
st.bodyw.Close() // no body
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
@ -317,6 +322,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
"Custom-Header": {"Custom header value", "Another custom header value"},
"Grpc-Encoding": {"gzip"},
}
wantTrailer := http.Header{
"Grpc-Status": {"0"},

View File

@ -404,6 +404,17 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
s.contentSubtype = contentSubtype
isGRPC = true
case "grpc-accept-encoding":
mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
if hf.Value == "" {
continue
}
compressors := hf.Value
if s.clientAdvertisedCompressors != "" {
compressors = s.clientAdvertisedCompressors + "," + compressors
}
s.clientAdvertisedCompressors = compressors
case "grpc-encoding":
s.recvCompress = hf.Value
case ":method":

View File

@ -257,6 +257,9 @@ type Stream struct {
fc *inFlow
wq *writeQuota
// Holds compressor names passed in grpc-accept-encoding metadata from the
// client. This is empty for the client side stream.
clientAdvertisedCompressors string
// Callback to state application's intentions to read data. This
// is used to adjust flow control, if needed.
requestRead func(int)
@ -345,8 +348,24 @@ func (s *Stream) RecvCompress() string {
}
// SetSendCompress sets the compression algorithm to the stream.
func (s *Stream) SetSendCompress(str string) {
s.sendCompress = str
func (s *Stream) SetSendCompress(name string) error {
if s.isHeaderSent() || s.getState() == streamDone {
return errors.New("transport: set send compressor called after headers sent or stream done")
}
s.sendCompress = name
return nil
}
// SendCompress returns the send compressor name.
func (s *Stream) SendCompress() string {
return s.sendCompress
}
// ClientAdvertisedCompressors returns the compressor names advertised by the
// client via grpc-accept-encoding header.
func (s *Stream) ClientAdvertisedCompressors() string {
return s.clientAdvertisedCompressors
}
// Done returns a channel which is closed when it receives the final status

100
server.go
View File

@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
@ -1263,6 +1264,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
var comp, decomp encoding.Compressor
var cp Compressor
var dc Decompressor
var sendCompressorName string
// If dc is set and matches the stream's compression, use it. Otherwise, try
// to find a matching registered compressor for decomp.
@ -1283,12 +1285,18 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
cp = s.opts.cp
stream.SetSendCompress(cp.Type())
sendCompressorName = cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
comp = encoding.GetCompressor(rc)
if comp != nil {
stream.SetSendCompress(rc)
sendCompressorName = comp.Name()
}
}
if sendCompressorName != "" {
if err := stream.SetSendCompress(sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}
@ -1375,6 +1383,11 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
opts := &transport.Options{Last: true}
// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if stream.SendCompress() != sendCompressorName {
comp = encoding.GetCompressor(stream.SendCompress())
}
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
if err == io.EOF {
// The entire stream is done (for unary RPC only).
@ -1597,12 +1610,18 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
// NOTE: this needs to be ahead of all handling, https://github.com/grpc/grpc-go/issues/686.
if s.opts.cp != nil {
ss.cp = s.opts.cp
stream.SetSendCompress(s.opts.cp.Type())
ss.sendCompressorName = s.opts.cp.Type()
} else if rc := stream.RecvCompress(); rc != "" && rc != encoding.Identity {
// Legacy compressor not specified; attempt to respond with same encoding.
ss.comp = encoding.GetCompressor(rc)
if ss.comp != nil {
stream.SetSendCompress(rc)
ss.sendCompressorName = rc
}
}
if ss.sendCompressorName != "" {
if err := stream.SetSendCompress(ss.sendCompressorName); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to set send compressor: %v", err)
}
}
@ -1935,6 +1954,60 @@ func SendHeader(ctx context.Context, md metadata.MD) error {
return nil
}
// SetSendCompressor sets a compressor for outbound messages from the server.
// It must not be called after any event that causes headers to be sent
// (see ServerStream.SetHeader for the complete list). Provided compressor is
// used when below conditions are met:
//
// - compressor is registered via encoding.RegisterCompressor
// - compressor name must exist in the client advertised compressor names
// sent in grpc-accept-encoding header. Use ClientSupportedCompressors to
// get client supported compressor names.
//
// The context provided must be the context passed to the server's handler.
// It must be noted that compressor name encoding.Identity disables the
// outbound compression.
// By default, server messages will be sent using the same compressor with
// which request messages were sent.
//
// It is not safe to call SetSendCompressor concurrently with SendHeader and
// SendMsg.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}
if err := validateSendCompressor(name, stream.ClientAdvertisedCompressors()); err != nil {
return fmt.Errorf("unable to set send compressor: %w", err)
}
return stream.SetSendCompress(name)
}
// ClientSupportedCompressors returns compressor names advertised by the client
// via grpc-accept-encoding header.
//
// The context provided must be the context passed to the server's handler.
//
// # Experimental
//
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}
return strings.Split(stream.ClientAdvertisedCompressors(), ","), nil
}
// SetTrailer sets the trailer metadata that will be sent when an RPC returns.
// When called more than once, all the provided metadata will be merged.
//
@ -1969,3 +2042,22 @@ type channelzServer struct {
func (c *channelzServer) ChannelzMetric() *channelz.ServerInternalMetric {
return c.s.channelzMetric()
}
// validateSendCompressor returns an error when given compressor name cannot be
// handled by the server or the client based on the advertised compressors.
func validateSendCompressor(name, clientCompressors string) error {
if name == encoding.Identity {
return nil
}
if !grpcutil.IsCompressorNameRegistered(name) {
return fmt.Errorf("compressor not registered %q", name)
}
for _, c := range strings.Split(clientCompressors, ",") {
if c == name {
return nil // found match
}
}
return fmt.Errorf("client does not support compressor %q", name)
}

View File

@ -1511,6 +1511,8 @@ type serverStream struct {
comp encoding.Compressor
decomp encoding.Compressor
sendCompressorName string
maxReceiveMessageSize int
maxSendMessageSize int
trInfo *traceInfo
@ -1603,6 +1605,13 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
}
}()
// Server handler could have set new compressor by calling SetSendCompressor.
// In case it is set, we need to use it for compressing outbound message.
if sendCompressorsName := ss.s.SendCompress(); sendCompressorsName != ss.sendCompressorName {
ss.comp = encoding.GetCompressor(sendCompressorsName)
ss.sendCompressorName = sendCompressorsName
}
// load hdr, payload, data
hdr, payload, data, err := prepareMsg(m, ss.codec, ss.cp, ss.comp)
if err != nil {

View File

@ -59,6 +59,7 @@ import (
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
@ -5080,6 +5081,340 @@ func (s) TestClientForwardsGrpcAcceptEncodingHeader(t *testing.T) {
}
}
// wrapCompressor is a wrapper of encoding.Compressor which maintains count of
// Compressor method invokes.
type wrapCompressor struct {
encoding.Compressor
compressInvokes int32
}
func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
atomic.AddInt32(&wc.compressInvokes, 1)
return wc.Compressor.Compress(w)
}
func setupGzipWrapCompressor(t *testing.T) *wrapCompressor {
oldC := encoding.GetCompressor("gzip")
c := &wrapCompressor{Compressor: oldC}
encoding.RegisterCompressor(c)
t.Cleanup(func() {
encoding.RegisterCompressor(oldC)
})
return c
}
func (s) TestSetSendCompressorSuccess(t *testing.T) {
for _, tt := range []struct {
name string
desc string
dialOpts []grpc.DialOption
resCompressor string
wantCompressInvokes int32
}{
{
name: "identity_request_and_gzip_response",
desc: "request is uncompressed and response is gzip compressed",
resCompressor: "gzip",
wantCompressInvokes: 1,
},
{
name: "gzip_request_and_identity_response",
desc: "request is gzip compressed and response is uncompressed with identity",
resCompressor: "identity",
dialOpts: []grpc.DialOption{
// Use WithCompressor instead of UseCompressor to avoid counting
// the client's compressor usage.
grpc.WithCompressor(grpc.NewGZIPCompressor()),
},
wantCompressInvokes: 0,
},
} {
t.Run(tt.name, func(t *testing.T) {
t.Run("unary", func(t *testing.T) {
testUnarySetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})
t.Run("stream", func(t *testing.T) {
testStreamSetSendCompressorSuccess(t, tt.resCompressor, tt.wantCompressInvokes, tt.dialOpts)
})
})
}
}
func testUnarySetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil, dialOpts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
}
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
if compressInvokes != wantCompressInvokes {
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
}
}
func testStreamSetSendCompressorSuccess(t *testing.T, resCompressor string, wantCompressInvokes int32, dialOpts []grpc.DialOption) {
wc := setupGzipWrapCompressor(t)
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{})
},
}
if err := ss.Start(nil, dialOpts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
}
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
}
if _, err := s.Recv(); err != nil {
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
}
compressInvokes := atomic.LoadInt32(&wc.compressInvokes)
if compressInvokes != wantCompressInvokes {
t.Fatalf("Unexpected compress invokes, got:%d, want: %d", compressInvokes, wantCompressInvokes)
}
}
func (s) TestUnregisteredSetSendCompressorFailure(t *testing.T) {
resCompressor := "snappy2"
wantErr := status.Error(codes.Unknown, "unable to set send compressor: compressor not registered \"snappy2\"")
t.Run("unary", func(t *testing.T) {
testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
})
t.Run("stream", func(t *testing.T) {
testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
})
}
func (s) TestUnadvertisedSetSendCompressorFailure(t *testing.T) {
// Disable client compressor advertisement.
defer func(b bool) { envconfig.AdvertiseCompressors = b }(envconfig.AdvertiseCompressors)
envconfig.AdvertiseCompressors = false
resCompressor := "gzip"
wantErr := status.Error(codes.Unknown, "unable to set send compressor: client does not support compressor \"gzip\"")
t.Run("unary", func(t *testing.T) {
testUnarySetSendCompressorFailure(t, resCompressor, wantErr)
})
t.Run("stream", func(t *testing.T) {
testStreamSetSendCompressorFailure(t, resCompressor, wantErr)
})
}
func testUnarySetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
if err := grpc.SetSendCompressor(ctx, resCompressor); err != nil {
return nil, err
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
}
}
func testStreamSetSendCompressorFailure(t *testing.T, resCompressor string, wantErr error) {
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
if _, err := stream.Recv(); err != nil {
return err
}
if err := grpc.SetSendCompressor(stream.Context(), resCompressor); err != nil {
return err
}
return stream.Send(&testpb.StreamingOutputCallResponse{})
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v, want: nil", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
}
if err := s.Send(&testpb.StreamingOutputCallRequest{}); err != nil {
t.Fatalf("Unexpected full duplex call send error, got: %v, want: nil", err)
}
if _, err := s.Recv(); !equalError(err, wantErr) {
t.Fatalf("Unexpected full duplex recv error, got: %v, want: nil", err)
}
}
func (s) TestUnarySetSendCompressorAfterHeaderSendFailure(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
// Send headers early and then set send compressor.
grpc.SendHeader(ctx, metadata.MD{})
err := grpc.SetSendCompressor(ctx, "gzip")
if err == nil {
t.Error("Wanted set send compressor error")
return &testpb.Empty{}, nil
}
return nil, err
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); !equalError(err, wantErr) {
t.Fatalf("Unexpected unary call error, got: %v, want: %v", err, wantErr)
}
}
func (s) TestStreamSetSendCompressorAfterHeaderSendFailure(t *testing.T) {
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
// Send headers early and then set send compressor.
grpc.SendHeader(stream.Context(), metadata.MD{})
err := grpc.SetSendCompressor(stream.Context(), "gzip")
if err == nil {
t.Error("Wanted set send compressor error")
}
return err
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
wantErr := status.Error(codes.Unknown, "transport: set send compressor called after headers sent or stream done")
s, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("Unexpected full duplex call error, got: %v, want: nil", err)
}
if _, err := s.Recv(); !equalError(err, wantErr) {
t.Fatalf("Unexpected full duplex recv error, got: %v, want: %v", err, wantErr)
}
}
func (s) TestClientSupportedCompressors(t *testing.T) {
for _, tt := range []struct {
desc string
ctx context.Context
want []string
}{
{
desc: "No additional grpc-accept-encoding header",
ctx: context.Background(),
want: []string{"gzip"},
},
{
desc: "With additional grpc-accept-encoding header",
ctx: metadata.AppendToOutgoingContext(context.Background(),
"grpc-accept-encoding", "test-compressor-1",
"grpc-accept-encoding", "test-compressor-2",
),
want: []string{"gzip", "test-compressor-1", "test-compressor-2"},
},
{
desc: "With additional empty grpc-accept-encoding header",
ctx: metadata.AppendToOutgoingContext(context.Background(),
"grpc-accept-encoding", "",
),
want: []string{"gzip"},
},
} {
t.Run(tt.desc, func(t *testing.T) {
ss := &stubserver.StubServer{
EmptyCallF: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
got, err := grpc.ClientSupportedCompressors(ctx)
if err != nil {
return nil, err
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("unexpected client compressors got: %v, want: %v", got, tt.want)
}
return &testpb.Empty{}, nil
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v, want: nil", err)
}
defer ss.Stop()
ctx, cancel := context.WithTimeout(tt.ctx, defaultTestTimeout)
defer cancel()
_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
if err != nil {
t.Fatalf("Unexpected unary call error, got: %v, want: nil", err)
}
})
}
}
func (s) TestUnaryProxyDoesNotForwardMetadata(t *testing.T) {
const mdkey = "somedata"