mirror of https://github.com/grpc/grpc-go.git
credentials/alts: Optimize reads (#8204)
This commit is contained in:
parent
4b5505d301
commit
b368379ef8
|
@ -54,11 +54,10 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) {
|
||||||
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
|
func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
|
||||||
// If the size field is not complete, return the provided buffer as
|
// If the size field is not complete, return the provided buffer as
|
||||||
// remaining buffer.
|
// remaining buffer.
|
||||||
if len(b) < MsgLenFieldSize {
|
length, sufficientBytes := parseMessageLength(b)
|
||||||
|
if !sufficientBytes {
|
||||||
return nil, b, nil
|
return nil, b, nil
|
||||||
}
|
}
|
||||||
msgLenField := b[:MsgLenFieldSize]
|
|
||||||
length := binary.LittleEndian.Uint32(msgLenField)
|
|
||||||
if length > maxLen {
|
if length > maxLen {
|
||||||
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
|
return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
|
||||||
}
|
}
|
||||||
|
@ -68,3 +67,14 @@ func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) {
|
||||||
}
|
}
|
||||||
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
|
return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseMessageLength returns the message length based on frame header. It also
|
||||||
|
// returns a boolean indicating if the buffer contains sufficient bytes to parse
|
||||||
|
// the length header. If there are insufficient bytes, (0, false) is returned.
|
||||||
|
func parseMessageLength(b []byte) (uint32, bool) {
|
||||||
|
if len(b) < MsgLenFieldSize {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
msgLenField := b[:MsgLenFieldSize]
|
||||||
|
return binary.LittleEndian.Uint32(msgLenField), true
|
||||||
|
}
|
||||||
|
|
|
@ -63,6 +63,8 @@ const (
|
||||||
// The maximum write buffer size. This *must* be multiple of
|
// The maximum write buffer size. This *must* be multiple of
|
||||||
// altsRecordDefaultLength.
|
// altsRecordDefaultLength.
|
||||||
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
|
altsWriteBufferMaxSize = 512 * 1024 // 512KiB
|
||||||
|
// The initial buffer used to read from the network.
|
||||||
|
altsReadBufferInitialSize = 32 * 1024 // 32KiB
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -83,7 +85,7 @@ type conn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
crypto ALTSRecordCrypto
|
crypto ALTSRecordCrypto
|
||||||
// buf holds data that has been read from the connection and decrypted,
|
// buf holds data that has been read from the connection and decrypted,
|
||||||
// but has not yet been returned by Read.
|
// but has not yet been returned by Read. It is a sub-slice of protected.
|
||||||
buf []byte
|
buf []byte
|
||||||
payloadLengthLimit int
|
payloadLengthLimit int
|
||||||
// protected holds data read from the network but have not yet been
|
// protected holds data read from the network but have not yet been
|
||||||
|
@ -111,21 +113,13 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
|
||||||
}
|
}
|
||||||
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
|
overhead := MsgLenFieldSize + msgTypeFieldSize + crypto.EncryptionOverhead()
|
||||||
payloadLengthLimit := altsRecordDefaultLength - overhead
|
payloadLengthLimit := altsRecordDefaultLength - overhead
|
||||||
var protectedBuf []byte
|
// We pre-allocate protected to be of size 32KB during initialization.
|
||||||
if protected == nil {
|
// We increase the size of the buffer by the required amount if it can't
|
||||||
// We pre-allocate protected to be of size
|
// hold a complete encrypted record.
|
||||||
// 2*altsRecordDefaultLength-1 during initialization. We only
|
protectedBuf := make([]byte, max(altsReadBufferInitialSize, len(protected)))
|
||||||
// read from the network into protected when protected does not
|
// Copy additional data from hanshaker service.
|
||||||
// contain a complete frame, which is at most
|
|
||||||
// altsRecordDefaultLength-1 (bytes). And we read at most
|
|
||||||
// altsRecordDefaultLength (bytes) data into protected at one
|
|
||||||
// time. Therefore, 2*altsRecordDefaultLength-1 is large enough
|
|
||||||
// to buffer data read from the network.
|
|
||||||
protectedBuf = make([]byte, 0, 2*altsRecordDefaultLength-1)
|
|
||||||
} else {
|
|
||||||
protectedBuf = make([]byte, len(protected))
|
|
||||||
copy(protectedBuf, protected)
|
copy(protectedBuf, protected)
|
||||||
}
|
protectedBuf = protectedBuf[:len(protected)]
|
||||||
|
|
||||||
altsConn := &conn{
|
altsConn := &conn{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
|
@ -162,11 +156,21 @@ func (p *conn) Read(b []byte) (n int, err error) {
|
||||||
// Check whether a complete frame has been received yet.
|
// Check whether a complete frame has been received yet.
|
||||||
for len(framedMsg) == 0 {
|
for len(framedMsg) == 0 {
|
||||||
if len(p.protected) == cap(p.protected) {
|
if len(p.protected) == cap(p.protected) {
|
||||||
tmp := make([]byte, len(p.protected), cap(p.protected)+altsRecordDefaultLength)
|
// We can parse the length header to know exactly how large
|
||||||
copy(tmp, p.protected)
|
// the buffer needs to be to hold the entire frame.
|
||||||
p.protected = tmp
|
length, didParse := parseMessageLength(p.protected)
|
||||||
|
if !didParse {
|
||||||
|
// The protected buffer is initialized with a capacity of
|
||||||
|
// larger than 4B. It should always hold the message length
|
||||||
|
// header.
|
||||||
|
panic(fmt.Sprintf("protected buffer length shorter than expected: %d vs %d", len(p.protected), MsgLenFieldSize))
|
||||||
}
|
}
|
||||||
n, err = p.Conn.Read(p.protected[len(p.protected):min(cap(p.protected), len(p.protected)+altsRecordDefaultLength)])
|
oldProtectedBuf := p.protected
|
||||||
|
p.protected = make([]byte, int(length)+MsgLenFieldSize)
|
||||||
|
copy(p.protected, oldProtectedBuf)
|
||||||
|
p.protected = p.protected[:len(oldProtectedBuf)]
|
||||||
|
}
|
||||||
|
n, err = p.Conn.Read(p.protected[len(p.protected):cap(p.protected)])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
@ -185,6 +189,15 @@ func (p *conn) Read(b []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
ciphertext := msg[msgTypeFieldSize:]
|
ciphertext := msg[msgTypeFieldSize:]
|
||||||
|
|
||||||
|
// Decrypt directly into the buffer, avoiding a copy from p.buf if
|
||||||
|
// possible.
|
||||||
|
if len(b) >= len(ciphertext) {
|
||||||
|
dec, err := p.crypto.Decrypt(b[:0], ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(dec), nil
|
||||||
|
}
|
||||||
// Decrypt requires that if the dst and ciphertext alias, they
|
// Decrypt requires that if the dst and ciphertext alias, they
|
||||||
// must alias exactly. Code here used to use msg[:0], but msg
|
// must alias exactly. Code here used to use msg[:0], but msg
|
||||||
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
|
// starts MsgLenFieldSize+msgTypeFieldSize bytes earlier than
|
||||||
|
|
|
@ -26,6 +26,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
core "google.golang.org/grpc/credentials/alts/internal"
|
core "google.golang.org/grpc/credentials/alts/internal"
|
||||||
|
@ -188,6 +189,48 @@ func (s) TestLargeMsg(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestLargeRecord writes a very large ALTS record and verifies that the server
|
||||||
|
// receives it correctly. The large ALTS record should cause the reader to
|
||||||
|
// expand it's read buffer to hold the entire record and store the decrypted
|
||||||
|
// message until the receiver reads all of the bytes.
|
||||||
|
func (s) TestLargeRecord(t *testing.T) {
|
||||||
|
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
|
||||||
|
msg := []byte(strings.Repeat("a", 2*altsReadBufferInitialSize))
|
||||||
|
// Increase the size of ALTS records written by the client.
|
||||||
|
clientConn.payloadLengthLimit = math.MaxInt32
|
||||||
|
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
|
||||||
|
t.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||||
|
}
|
||||||
|
rcvMsg := make([]byte, len(msg))
|
||||||
|
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
|
||||||
|
t.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(msg, rcvMsg) {
|
||||||
|
t.Fatalf("Write()/Server Read() = %v, want %v", rcvMsg, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLargeMessage measures the performance of ALTS conns for sending and
|
||||||
|
// receiving a large message.
|
||||||
|
func BenchmarkLargeMessage(b *testing.B) {
|
||||||
|
msgLen := 20 * 1024 * 1024 // 20 MiB
|
||||||
|
msg := make([]byte, msgLen)
|
||||||
|
rcvMsg := make([]byte, len(msg))
|
||||||
|
b.ResetTimer()
|
||||||
|
clientConn, serverConn := newConnPair(rekeyRecordProtocol, nil, nil)
|
||||||
|
for range b.N {
|
||||||
|
// Write 20 MiB 5 times to transfer a total of 100 MiB.
|
||||||
|
for range 5 {
|
||||||
|
if n, err := clientConn.Write(msg); n != len(msg) || err != nil {
|
||||||
|
b.Fatalf("Write() = %v, %v; want %v, <nil>", n, err, len(msg))
|
||||||
|
}
|
||||||
|
if n, err := io.ReadFull(serverConn, rcvMsg); n != len(rcvMsg) || err != nil {
|
||||||
|
b.Fatalf("Read() = %v, %v; want %v, <nil>", n, err, len(rcvMsg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func testIncorrectMsgType(t *testing.T, rp string) {
|
func testIncorrectMsgType(t *testing.T, rp string) {
|
||||||
// framedMsg is an empty ciphertext with correct framing but wrong
|
// framedMsg is an empty ciphertext with correct framing but wrong
|
||||||
// message type.
|
// message type.
|
||||||
|
|
|
@ -308,6 +308,7 @@ func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*al
|
||||||
// whatever received from the network and send it to the handshaker service.
|
// whatever received from the network and send it to the handshaker service.
|
||||||
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
|
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
|
||||||
var lastWriteTime time.Time
|
var lastWriteTime time.Time
|
||||||
|
buf := make([]byte, frameLimit)
|
||||||
for {
|
for {
|
||||||
if len(resp.OutFrames) > 0 {
|
if len(resp.OutFrames) > 0 {
|
||||||
lastWriteTime = time.Now()
|
lastWriteTime = time.Now()
|
||||||
|
@ -318,7 +319,6 @@ func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []b
|
||||||
if resp.Result != nil {
|
if resp.Result != nil {
|
||||||
return resp.Result, extra, nil
|
return resp.Result, extra, nil
|
||||||
}
|
}
|
||||||
buf := make([]byte, frameLimit)
|
|
||||||
n, err := h.conn.Read(buf)
|
n, err := h.conn.Read(buf)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|
Loading…
Reference in New Issue