diff --git a/credentials/alts/internal/conn/common.go b/credentials/alts/internal/conn/common.go index 46617132a..d4c3ab798 100644 --- a/credentials/alts/internal/conn/common.go +++ b/credentials/alts/internal/conn/common.go @@ -19,9 +19,7 @@ package conn import ( - "encoding/binary" "errors" - "fmt" ) const ( @@ -48,33 +46,3 @@ func SliceForAppend(in []byte, n int) (head, tail []byte) { tail = head[len(in):] return head, tail } - -// ParseFramedMsg parse the provided buffer and returns a frame of the format -// msgLength+msg and any remaining bytes in that buffer. -func ParseFramedMsg(b []byte, maxLen uint32) ([]byte, []byte, error) { - // If the size field is not complete, return the provided buffer as - // remaining buffer. - length, sufficientBytes := parseMessageLength(b) - if !sufficientBytes { - return nil, b, nil - } - if length > maxLen { - return nil, nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) - } - if len(b) < int(length)+4 { // account for the first 4 msg length bytes. - // Frame is not complete yet. - return nil, b, nil - } - return b[:MsgLenFieldSize+length], b[MsgLenFieldSize+length:], nil -} - -// parseMessageLength returns the message length based on frame header. It also -// returns a boolean indicating if the buffer contains sufficient bytes to parse -// the length header. If there are insufficient bytes, (0, false) is returned. -func parseMessageLength(b []byte) (uint32, bool) { - if len(b) < MsgLenFieldSize { - return 0, false - } - msgLenField := b[:MsgLenFieldSize] - return binary.LittleEndian.Uint32(msgLenField), true -} diff --git a/credentials/alts/internal/conn/record.go b/credentials/alts/internal/conn/record.go index f9d2646d4..45d09f130 100644 --- a/credentials/alts/internal/conn/record.go +++ b/credentials/alts/internal/conn/record.go @@ -144,7 +144,7 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot func (p *conn) Read(b []byte) (n int, err error) { if len(p.buf) == 0 { var framedMsg []byte - framedMsg, p.nextFrame, err = ParseFramedMsg(p.nextFrame, altsRecordLengthLimit) + framedMsg, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit) if err != nil { return n, err } @@ -184,7 +184,7 @@ func (p *conn) Read(b []byte) (n int, err error) { return 0, err } p.protected = p.protected[:len(p.protected)+n] - framedMsg, p.nextFrame, err = ParseFramedMsg(p.protected, altsRecordLengthLimit) + framedMsg, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit) if err != nil { return 0, err } @@ -225,6 +225,38 @@ func (p *conn) Read(b []byte) (n int, err error) { return n, nil } +// parseFramedMsg parses the provided buffer and returns a frame of the format +// msgLength+msg iff a full frame is available. +func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, error) { + // If the size field is not complete, return the provided buffer as + // remaining buffer. + p.nextFrame = b + length, sufficientBytes := parseMessageLength(b) + if !sufficientBytes { + return nil, nil + } + if length > maxLen { + return nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) + } + if len(b) < int(length)+4 { // account for the first 4 msg length bytes. + // Frame is not complete yet. + return nil, nil + } + p.nextFrame = b[MsgLenFieldSize+length:] + return b[:MsgLenFieldSize+length], nil +} + +// parseMessageLength returns the message length based on frame header. It also +// returns a boolean indicating if the buffer contains sufficient bytes to parse +// the length header. If there are insufficient bytes, (0, false) is returned. +func parseMessageLength(b []byte) (uint32, bool) { + if len(b) < MsgLenFieldSize { + return 0, false + } + msgLenField := b[:MsgLenFieldSize] + return binary.LittleEndian.Uint32(msgLenField), true +} + // Write encrypts, frames, and writes bytes from b to the underlying connection. func (p *conn) Write(b []byte) (n int, err error) { n = len(b)