mirror of https://github.com/grpc/grpc-go.git
grpc: fix message length checks when compression is enabled and maxReceiveMessageSize is MaxInt (#7918)
This commit is contained in:
parent
67bee55a47
commit
8cf8fd1433
69
rpc_util.go
69
rpc_util.go
|
@ -828,30 +828,13 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
|
|||
return nil, st.Err()
|
||||
}
|
||||
|
||||
var size int
|
||||
if pf.isCompressed() {
|
||||
defer compressed.Free()
|
||||
|
||||
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
|
||||
// use this decompressor as the default.
|
||||
if dc != nil {
|
||||
var uncompressedBuf []byte
|
||||
uncompressedBuf, err = dc.Do(compressed.Reader())
|
||||
if err == nil {
|
||||
out = mem.BufferSlice{mem.SliceBuffer(uncompressedBuf)}
|
||||
}
|
||||
size = len(uncompressedBuf)
|
||||
} else {
|
||||
out, size, err = decompress(compressor, compressed, maxReceiveMessageSize, p.bufferPool)
|
||||
}
|
||||
out, err = decompress(compressor, compressed, dc, maxReceiveMessageSize, p.bufferPool)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
||||
}
|
||||
if size > maxReceiveMessageSize {
|
||||
out.Free()
|
||||
// TODO: Revisit the error code. Currently keep it consistent with java
|
||||
// implementation.
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
out = compressed
|
||||
|
@ -866,20 +849,46 @@ func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveM
|
|||
return out, nil
|
||||
}
|
||||
|
||||
// Using compressor, decompress d, returning data and size.
|
||||
// Optionally, if data will be over maxReceiveMessageSize, just return the size.
|
||||
func decompress(compressor encoding.Compressor, d mem.BufferSlice, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, int, error) {
|
||||
dcReader, err := compressor.Decompress(d.Reader())
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
// decompress processes the given data by decompressing it using either a custom decompressor or a standard compressor.
|
||||
// If a custom decompressor is provided, it takes precedence. The function validates that the decompressed data
|
||||
// does not exceed the specified maximum size and returns an error if this limit is exceeded.
|
||||
// On success, it returns the decompressed data. Otherwise, it returns an error if decompression fails or the data exceeds the size limit.
|
||||
func decompress(compressor encoding.Compressor, d mem.BufferSlice, dc Decompressor, maxReceiveMessageSize int, pool mem.BufferPool) (mem.BufferSlice, error) {
|
||||
if dc != nil {
|
||||
uncompressed, err := dc.Do(d.Reader())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
|
||||
}
|
||||
if len(uncompressed) > maxReceiveMessageSize {
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", len(uncompressed), maxReceiveMessageSize)
|
||||
}
|
||||
return mem.BufferSlice{mem.SliceBuffer(uncompressed)}, nil
|
||||
}
|
||||
if compressor != nil {
|
||||
dcReader, err := compressor.Decompress(d.Reader())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the message: %v", err)
|
||||
}
|
||||
|
||||
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1), pool)
|
||||
if err != nil {
|
||||
out.Free()
|
||||
return nil, 0, err
|
||||
out, err := mem.ReadAll(io.LimitReader(dcReader, int64(maxReceiveMessageSize)), pool)
|
||||
if err != nil {
|
||||
out.Free()
|
||||
return nil, status.Errorf(codes.Internal, "grpc: failed to read decompressed data: %v", err)
|
||||
}
|
||||
|
||||
if out.Len() == maxReceiveMessageSize && !atEOF(dcReader) {
|
||||
out.Free()
|
||||
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", maxReceiveMessageSize)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
return out, out.Len(), nil
|
||||
return nil, status.Errorf(codes.Internal, "grpc: no decompressor available for compressed payload")
|
||||
}
|
||||
|
||||
// atEOF reads data from r and returns true if zero bytes could be read and r.Read returns EOF.
|
||||
func atEOF(dcReader io.Reader) bool {
|
||||
n, err := dcReader.Read(make([]byte, 1))
|
||||
return n == 0 && err == io.EOF
|
||||
}
|
||||
|
||||
type recvCompressor interface {
|
||||
|
|
127
rpc_util_test.go
127
rpc_util_test.go
|
@ -21,12 +21,17 @@ package grpc
|
|||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/encoding"
|
||||
_ "google.golang.org/grpc/encoding/gzip"
|
||||
protoenc "google.golang.org/grpc/encoding/proto"
|
||||
"google.golang.org/grpc/internal/testutils"
|
||||
"google.golang.org/grpc/internal/transport"
|
||||
|
@ -36,6 +41,11 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDecompressedData = "default decompressed data"
|
||||
decompressionErrorMsg = "invalid compression format"
|
||||
)
|
||||
|
||||
type fullReader struct {
|
||||
data []byte
|
||||
}
|
||||
|
@ -294,3 +304,120 @@ func BenchmarkGZIPCompressor512KiB(b *testing.B) {
|
|||
func BenchmarkGZIPCompressor1MiB(b *testing.B) {
|
||||
bmCompressor(b, 1024*1024, NewGZIPCompressor())
|
||||
}
|
||||
|
||||
// compressWithDeterministicError compresses the input data and returns a BufferSlice.
|
||||
func compressWithDeterministicError(t *testing.T, input []byte) mem.BufferSlice {
|
||||
t.Helper()
|
||||
var buf bytes.Buffer
|
||||
gz := gzip.NewWriter(&buf)
|
||||
if _, err := gz.Write(input); err != nil {
|
||||
t.Fatalf("compressInput() failed to write data: %v", err)
|
||||
}
|
||||
if err := gz.Close(); err != nil {
|
||||
t.Fatalf("compressInput() failed to close gzip writer: %v", err)
|
||||
}
|
||||
compressedData := buf.Bytes()
|
||||
return mem.BufferSlice{mem.NewBuffer(&compressedData, nil)}
|
||||
}
|
||||
|
||||
// MockDecompressor is a mock implementation of a decompressor used for testing purposes.
|
||||
// It simulates decompression behavior, returning either decompressed data or an error based on the ShouldError flag.
|
||||
type MockDecompressor struct {
|
||||
ShouldError bool // Flag to control whether the decompression should simulate an error.
|
||||
}
|
||||
|
||||
// Do simulates decompression. It returns a predefined error if ShouldError is true,
|
||||
// or a fixed set of decompressed data if ShouldError is false.
|
||||
func (m *MockDecompressor) Do(_ io.Reader) ([]byte, error) {
|
||||
if m.ShouldError {
|
||||
return nil, errors.New(decompressionErrorMsg)
|
||||
}
|
||||
return []byte(defaultDecompressedData), nil
|
||||
}
|
||||
|
||||
// Type returns the string identifier for the MockDecompressor.
|
||||
func (m *MockDecompressor) Type() string {
|
||||
return "MockDecompressor"
|
||||
}
|
||||
|
||||
// TestDecompress tests the decompress function behaves correctly for following scenarios
|
||||
// decompress successfully when message is <= maxReceiveMessageSize
|
||||
// errors when message > maxReceiveMessageSize
|
||||
// decompress successfully when maxReceiveMessageSize is MaxInt
|
||||
// errors when the decompressed message has an invalid format
|
||||
// errors when the decompressed message exceeds the maxReceiveMessageSize.
|
||||
func (s) TestDecompress(t *testing.T) {
|
||||
compressor := encoding.GetCompressor("gzip")
|
||||
validDecompressor := &MockDecompressor{ShouldError: false}
|
||||
invalidFormatDecompressor := &MockDecompressor{ShouldError: true}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input mem.BufferSlice
|
||||
dc Decompressor
|
||||
maxReceiveMessageSize int
|
||||
want []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Decompresses successfully with sufficient buffer size",
|
||||
input: compressWithDeterministicError(t, []byte("decompressed data")),
|
||||
dc: nil,
|
||||
maxReceiveMessageSize: 50,
|
||||
want: []byte("decompressed data"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Fails due to exceeding maxReceiveMessageSize",
|
||||
input: compressWithDeterministicError(t, []byte("message that is too large")),
|
||||
dc: nil,
|
||||
maxReceiveMessageSize: len("message that is too large") - 1,
|
||||
want: nil,
|
||||
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max %d", len("message that is too large")-1),
|
||||
},
|
||||
{
|
||||
name: "Decompresses to exactly maxReceiveMessageSize",
|
||||
input: compressWithDeterministicError(t, []byte("exact size message")),
|
||||
dc: nil,
|
||||
maxReceiveMessageSize: len("exact size message"),
|
||||
want: []byte("exact size message"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Decompresses successfully with maxReceiveMessageSize MaxInt",
|
||||
input: compressWithDeterministicError(t, []byte("large message")),
|
||||
dc: nil,
|
||||
maxReceiveMessageSize: math.MaxInt,
|
||||
want: []byte("large message"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Fails with decompression error due to invalid format",
|
||||
input: compressWithDeterministicError(t, []byte("invalid compressed data")),
|
||||
dc: invalidFormatDecompressor,
|
||||
maxReceiveMessageSize: 50,
|
||||
want: nil,
|
||||
wantErr: status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", errors.New(decompressionErrorMsg)),
|
||||
},
|
||||
{
|
||||
name: "Fails with resourceExhausted error when decompressed message exceeds maxReceiveMessageSize",
|
||||
input: compressWithDeterministicError(t, []byte("large compressed data")),
|
||||
dc: validDecompressor,
|
||||
maxReceiveMessageSize: 20,
|
||||
want: nil,
|
||||
wantErr: status.Errorf(codes.ResourceExhausted, "grpc: message after decompression larger than max (%d vs. %d)", 25, 20),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
output, err := decompress(compressor, tc.input, tc.dc, tc.maxReceiveMessageSize, mem.DefaultBufferPool())
|
||||
if !cmp.Equal(err, tc.wantErr, cmpopts.EquateErrors()) {
|
||||
t.Fatalf("decompress() err = %v, wantErr = %v", err, tc.wantErr)
|
||||
}
|
||||
if !cmp.Equal(tc.want, output.Materialize()) {
|
||||
t.Fatalf("decompress() output mismatch: got = %v, want = %v", output.Materialize(), tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue