* Revert "credentials/alts: Add comments to clarify buffer sizing (#8232)"

This reverts commit be25d96c52.

* Revert "credentials/alts: Optimize reads (#8204)"

This reverts commit b368379ef8.
This commit is contained in:
Arjan Singh Bal 2025-04-08 23:21:49 +05:30 committed by GitHub
parent 25c750934e
commit 6bfa0ca35b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 95 deletions

View File

@ -54,10 +54,11 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
length, sufficientBytes := parseMessageLength(b)
if !sufficientBytes {
if len(b) < MsgLenFieldSize {
return nil, b, nil
}
msgLenField := b[:MsgLenFieldSize]
length := binary.LittleEndian.Uint32(msgLenField)
if length > maxLen {
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
@ -67,14 +68,3 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
}
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
}
// parseMessageLength returns the message length based on frame header. It also
// returns a boolean indicating if the buffer contains sufficient bytes to parse
// the length header. If there are insufficient bytes, (0, false) is returned.
func parseMessageLength(b []byte) (uint32, bool) {
if len(b) < MsgLenFieldSize {
return 0, false
}
msgLenField := b[:MsgLenFieldSize]
return binary.LittleEndian.Uint32(msgLenField), true
}

View File

@ -63,8 +63,6 @@ const (
// The maximum write buffer size. This *must* be multiple of
// altsRecordDefaultLength.
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
// The initial buffer used to read from the network.
altsReadBufferInitialSize = 32 * 1024 // 32KiB
)
var (
@ -85,7 +83,7 @@ type conn struct {
net.Conn
crypto ALTSRecordCrypto
// buf holds data that has been read from the connection and decrypted,
// but has not yet been returned by Read. It is a sub-slice of protected.
// but has not yet been returned by Read.
buf []byte
payloadLengthLimit int
// protected holds data read from the network but have not yet been
@ -113,13 +111,21 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
}
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
payloadLengthLimit := altsRecordDefaultLength - overhead
// We pre-allocate protected to be of size 32KB during initialization.
// We increase the size of the buffer by the required amount if it can't
// hold a complete encrypted record.
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
// Copy additional data from hanshaker service.
copy(protectedBuf, protected)
protectedBuf = protectedBuf[:len(protected)]
var protectedBuf []byte
if protected == nil {
// We pre-allocate protected to be of size
// 2*altsRecordDefaultLength-1 during initialization. We only
// read from the network into protected when protected does not
// contain a complete frame, which is at most
// altsRecordDefaultLength-1 (bytes). And we read at most
// altsRecordDefaultLength (bytes) data into protected at one
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
// to buffer data read from the network.
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
} else {
protectedBuf = make([]byte, len(protected))
copy(protectedBuf, protected)
}
altsConn := &conn{
Conn: c,
@ -156,26 +162,11 @@ func (p *conn) Read(b []byte) (n int, err error) {
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if len(p.protected) == cap(p.protected) {
// We can parse the length header to know exactly how large
// the buffer needs to be to hold the entire frame.
length, didParse := parseMessageLength(p.protected)
if !didParse {
// The protected buffer is initialized with a capacity of
// larger than 4B. It should always hold the message length
// header.
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
}
oldProtectedBuf := p.protected
// The new buffer must be able to hold the message length header
// and the entire message.
requiredCapacity := int(length) + MsgLenFieldSize
p.protected = make([]byte, requiredCapacity)
// Copy the contents of the old buffer and set the length of the
// new buffer to the number of bytes already read.
copy(p.protected, oldProtectedBuf)
p.protected = p.protected[:len(oldProtectedBuf)]
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
copy(tmp, p.protected)
p.protected = tmp
}
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
if err != nil {
return 0, err
}
@ -194,15 +185,6 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
ciphertext := msg[msgTypeFieldSize:]
// Decrypt directly into the buffer, avoiding a copy from p.buf if
// possible.
if len(b) >= len(ciphertext) {
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
if err != nil {
return 0, err
}
return len(dec), nil
}
// Decrypt requires that if the dst and ciphertext alias, they
// must alias exactly. Code here used to use msg[:0], but msg
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than

View File

@ -26,7 +26,6 @@ import (
"math"
"net"
"reflect"
"strings"
"testing"
core "google.golang.org/grpc/credentials/alts/internal"
@ -189,48 +188,6 @@ func (s) TestLargeMsg(t *testing.T) {
}
}
// TestLargeRecord writes a very large ALTS record and verifies that the server
// receives it correctly. The large ALTS record should cause the reader to
// expand it's read buffer to hold the entire record and store the decrypted
// message until the receiver reads all of the bytes.
func (s) TestLargeRecord(t *testing.T) {
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
// Increase the size of ALTS records written by the client.
clientConn.payloadLengthLimit = math.MaxInt32
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
rcvMsg := make([]byte, len(msg))
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
if !reflect.DeepEqual(msg, rcvMsg) {
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
}
}
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
// receiving a large message.
func BenchmarkLargeMessage(b *testing.B) {
msgLen := 20 * 1024 * 1024 // 20 MiB
msg := make([]byte, msgLen)
rcvMsg := make([]byte, len(msg))
b.ResetTimer()
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
for range b.N {
// Write 20 MiB 5 times to transfer a total of 100 MiB.
for range 5 {
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
}
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
}
}
}
}
func testIncorrectMsgType(t *testing.T, rp string) {
// framedMsg is an empty ciphertext with correct framing but wrong
// message type.

View File

@ -308,7 +308,6 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
// whatever received from the network and send it to the handshaker service.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
var lastWriteTime time.Time
buf := make([]byte, frameLimit)
for {
if len(resp.OutFrames) > 0 {
lastWriteTime = time.Now()
@ -319,6 +318,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
if resp.Result != nil {
return resp.Result, extra, nil
}
buf := make([]byte, frameLimit)
n, err := h.conn.Read(buf)
if err != nil && err != io.EOF {
return nil, nil, err