mirror of https://github.com/grpc/grpc-go.git
* Revert "credentials/alts: Add comments to clarify buffer sizing (#8232)" This reverts commitbe25d96c52
. * Revert "credentials/alts: Optimize reads (#8204)" This reverts commitb368379ef8
.
This commit is contained in:
parent
25c750934e
commit
6bfa0ca35b
|
@ -54,10 +54,11 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
|
|||
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
|
||||
// If the size field is not complete, return the provided buffer as
|
||||
// remaining buffer.
|
||||
length, sufficientBytes := parseMessageLength(b)
|
||||
if !sufficientBytes {
|
||||
if len(b) < MsgLenFieldSize {
|
||||
return nil, b, nil
|
||||
}
|
||||
msgLenField := b[:MsgLenFieldSize]
|
||||
length := binary.LittleEndian.Uint32(msgLenField)
|
||||
if length > maxLen {
|
||||
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
|
||||
}
|
||||
|
@ -67,14 +68,3 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
|
|||
}
|
||||
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
|
||||
}
|
||||
|
||||
// parseMessageLength returns the message length based on frame header. It also
|
||||
// returns a boolean indicating if the buffer contains sufficient bytes to parse
|
||||
// the length header. If there are insufficient bytes, (0, false) is returned.
|
||||
func parseMessageLength(b []byte) (uint32, bool) {
|
||||
if len(b) < MsgLenFieldSize {
|
||||
return 0, false
|
||||
}
|
||||
msgLenField := b[:MsgLenFieldSize]
|
||||
return binary.LittleEndian.Uint32(msgLenField), true
|
||||
}
|
||||
|
|
|
@ -63,8 +63,6 @@ const (
|
|||
// The maximum write buffer size. This *must* be multiple of
|
||||
// altsRecordDefaultLength.
|
||||
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
|
||||
// The initial buffer used to read from the network.
|
||||
altsReadBufferInitialSize = 32 * 1024 // 32KiB
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -85,7 +83,7 @@ type conn struct {
|
|||
net.Conn
|
||||
crypto ALTSRecordCrypto
|
||||
// buf holds data that has been read from the connection and decrypted,
|
||||
// but has not yet been returned by Read. It is a sub-slice of protected.
|
||||
// but has not yet been returned by Read.
|
||||
buf []byte
|
||||
payloadLengthLimit int
|
||||
// protected holds data read from the network but have not yet been
|
||||
|
@ -113,13 +111,21 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
|
|||
}
|
||||
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
|
||||
payloadLengthLimit := altsRecordDefaultLength - overhead
|
||||
// We pre-allocate protected to be of size 32KB during initialization.
|
||||
// We increase the size of the buffer by the required amount if it can't
|
||||
// hold a complete encrypted record.
|
||||
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
|
||||
// Copy additional data from hanshaker service.
|
||||
copy(protectedBuf, protected)
|
||||
protectedBuf = protectedBuf[:len(protected)]
|
||||
var protectedBuf []byte
|
||||
if protected == nil {
|
||||
// We pre-allocate protected to be of size
|
||||
// 2*altsRecordDefaultLength-1 during initialization. We only
|
||||
// read from the network into protected when protected does not
|
||||
// contain a complete frame, which is at most
|
||||
// altsRecordDefaultLength-1 (bytes). And we read at most
|
||||
// altsRecordDefaultLength (bytes) data into protected at one
|
||||
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
|
||||
// to buffer data read from the network.
|
||||
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
|
||||
} else {
|
||||
protectedBuf = make([]byte, len(protected))
|
||||
copy(protectedBuf, protected)
|
||||
}
|
||||
|
||||
altsConn := &conn{
|
||||
Conn: c,
|
||||
|
@ -156,26 +162,11 @@ func (p *conn) Read(b []byte) (n int, err error) {
|
|||
// Check whether a complete frame has been received yet.
|
||||
for len(framedMsg) == 0 {
|
||||
if len(p.protected) == cap(p.protected) {
|
||||
// We can parse the length header to know exactly how large
|
||||
// the buffer needs to be to hold the entire frame.
|
||||
length, didParse := parseMessageLength(p.protected)
|
||||
if !didParse {
|
||||
// The protected buffer is initialized with a capacity of
|
||||
// larger than 4B. It should always hold the message length
|
||||
// header.
|
||||
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
|
||||
}
|
||||
oldProtectedBuf := p.protected
|
||||
// The new buffer must be able to hold the message length header
|
||||
// and the entire message.
|
||||
requiredCapacity := int(length) + MsgLenFieldSize
|
||||
p.protected = make([]byte, requiredCapacity)
|
||||
// Copy the contents of the old buffer and set the length of the
|
||||
// new buffer to the number of bytes already read.
|
||||
copy(p.protected, oldProtectedBuf)
|
||||
p.protected = p.protected[:len(oldProtectedBuf)]
|
||||
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
|
||||
copy(tmp, p.protected)
|
||||
p.protected = tmp
|
||||
}
|
||||
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
|
||||
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -194,15 +185,6 @@ func (p *conn) Read(b []byte) (n int, err error) {
|
|||
}
|
||||
ciphertext := msg[msgTypeFieldSize:]
|
||||
|
||||
// Decrypt directly into the buffer, avoiding a copy from p.buf if
|
||||
// possible.
|
||||
if len(b) >= len(ciphertext) {
|
||||
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(dec), nil
|
||||
}
|
||||
// Decrypt requires that if the dst and ciphertext alias, they
|
||||
// must alias exactly. Code here used to use msg[:0], but msg
|
||||
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
|
||||
|
|
|
@ -26,7 +26,6 @@ import (
|
|||
"math"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
core "google.golang.org/grpc/credentials/alts/internal"
|
||||
|
@ -189,48 +188,6 @@ func (s) TestLargeMsg(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestLargeRecord writes a very large ALTS record and verifies that the server
|
||||
// receives it correctly. The large ALTS record should cause the reader to
|
||||
// expand it's read buffer to hold the entire record and store the decrypted
|
||||
// message until the receiver reads all of the bytes.
|
||||
func (s) TestLargeRecord(t *testing.T) {
|
||||
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
|
||||
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
|
||||
// Increase the size of ALTS records written by the client.
|
||||
clientConn.payloadLengthLimit = math.MaxInt32
|
||||
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
|
||||
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||
}
|
||||
rcvMsg := make([]byte, len(msg))
|
||||
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
|
||||
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
|
||||
}
|
||||
if !reflect.DeepEqual(msg, rcvMsg) {
|
||||
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
|
||||
// receiving a large message.
|
||||
func BenchmarkLargeMessage(b *testing.B) {
|
||||
msgLen := 20 * 1024 * 1024 // 20 MiB
|
||||
msg := make([]byte, msgLen)
|
||||
rcvMsg := make([]byte, len(msg))
|
||||
b.ResetTimer()
|
||||
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
|
||||
for range b.N {
|
||||
// Write 20 MiB 5 times to transfer a total of 100 MiB.
|
||||
for range 5 {
|
||||
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
|
||||
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||
}
|
||||
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
|
||||
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testIncorrectMsgType(t *testing.T, rp string) {
|
||||
// framedMsg is an empty ciphertext with correct framing but wrong
|
||||
// message type.
|
||||
|
|
|
@ -308,7 +308,6 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
|
|||
// whatever received from the network and send it to the handshaker service.
|
||||
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
|
||||
var lastWriteTime time.Time
|
||||
buf := make([]byte, frameLimit)
|
||||
for {
|
||||
if len(resp.OutFrames) > 0 {
|
||||
lastWriteTime = time.Now()
|
||||
|
@ -319,6 +318,7 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
|
|||
if resp.Result != nil {
|
||||
return resp.Result, extra, nil
|
||||
}
|
||||
buf := make([]byte, frameLimit)
|
||||
n, err := h.conn.Read(buf)
|
||||
if err != nil && err != io.EOF {
|
||||
return nil, nil, err
|
||||
|
|
Loading…
Reference in New Issue