diff --git a/credentials/alts/internal/conn/conn.go b/credentials/alts/internal/conn/conn.go new file mode 100644 index 000000000..10f792497 --- /dev/null +++ b/credentials/alts/internal/conn/conn.go @@ -0,0 +1,26 @@ +//go:build !linux + +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +// SO_RCVLOWAT exists on non-Linux OSes, but we have't tested them. +func (p *conn) setRcvlowat(length int) error { + return nil +} diff --git a/credentials/alts/internal/conn/conn_linux.go b/credentials/alts/internal/conn/conn_linux.go new file mode 100644 index 000000000..5bba24b8b --- /dev/null +++ b/credentials/alts/internal/conn/conn_linux.go @@ -0,0 +1,79 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package conn + +import ( + "errors" + + "golang.org/x/sys/unix" +) + +// setRcvlowat updates SO_RCVLOWAT to reduce CPU usage. +func (p *conn) setRcvlowat(length int) error { + if p.rawConn == nil { + return nil + } + + const ( + rcvlowatMax = 16 * 1024 * 1024 + rcvlowatMin = 32 * 1024 + rcvlowatGap = 16 * 1024 + ) + + remaining := min(cap(p.protected), length, rcvlowatMax) + + // Small SO_RCVLOWAT values don't actually save CPU. + if remaining < rcvlowatMin { + remaining = 0 + } + + // Allow for a small gap, which can wake us up a tiny bit early. This + // helps with latency, as bytes can arrive between wakeup and the + // ensuing read. + if remaining > 0 { + remaining -= rcvlowatGap + } + + // Don't hold up the socket once we've hit our threshold. + if len(p.protected) > remaining { + remaining = 0 + } + + // Don't enable SO_RCVLOWAT if it's not useful. + if p.rcvlowat <= 1 && remaining <= 1 { + return nil + } + + // Don't make a syscall if nothing changed. + if p.rcvlowat == remaining { + return nil + } + + // Make the actual setsockopt call. + var sockoptErr error + err := p.rawConn.Control(func(fd uintptr) { + sockoptErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVLOWAT, p.rcvlowat) + }) + if err != nil || sockoptErr != nil { + return errors.Join(err, sockoptErr) + } + + p.rcvlowat = remaining + return nil +} diff --git a/credentials/alts/internal/conn/record.go b/credentials/alts/internal/conn/record.go index 67f43af24..bfc273d59 100644 --- a/credentials/alts/internal/conn/record.go +++ b/credentials/alts/internal/conn/record.go @@ -25,6 +25,7 @@ import ( "fmt" "math" "net" + "syscall" core "google.golang.org/grpc/credentials/alts/internal" ) @@ -97,6 +98,18 @@ type conn struct { nextFrame []byte // overhead is the calculated overhead of each frame. overhead int + // rcvlowat is the "receive low watermark" used to avoid unnecessary + // early returns from the kernel during [conn.Read], which saves CPU and + // can boost throughput under load. When we receive the first few bytes + // of a message we examine the length field. If, for example, we know + // there's 512KB of data remaining in the record, rcvlowat tells the + // kernel "don't wake me up every time you get another packet; wait + // until you have all 512KB." + // + // See SO_RCVLOWAT in tcp(7) for more info. + rcvlowat int + // rawConn allows us to set SO_RCVLOWAT on the underlying TCP socket. + rawConn syscall.RawConn } // NewConn creates a new secure channel instance given the other party role and @@ -129,6 +142,18 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot nextFrame: protectedBuf, overhead: overhead, } + + if rcvlowat { + tcpConn, ok := c.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("rcvlowat requires a *net.TCPConn, but got %T", c) + } + if altsConn.rawConn, err = tcpConn.SyscallConn(); err != nil { + return nil, fmt.Errorf("failed to get raw connection: %w", err) + } + altsConn.rcvlowat = 1 + } + return altsConn, nil } @@ -139,7 +164,8 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot func (p *conn) Read(b []byte) (n int, err error) { if len(p.buf) == 0 { var framedMsg []byte - framedMsg, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit) + var length uint32 + framedMsg, length, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit) if err != nil { return n, err } @@ -154,6 +180,10 @@ func (p *conn) Read(b []byte) (n int, err error) { } // Check whether a complete frame has been received yet. for len(framedMsg) == 0 { + if err := p.setRcvlowat(int(length)); err != nil { + return 0, err + } + if len(p.protected) == cap(p.protected) { // We can parse the length header to know exactly how large // the buffer needs to be to hold the entire frame. @@ -179,7 +209,7 @@ func (p *conn) Read(b []byte) (n int, err error) { return 0, err } p.protected = p.protected[:len(p.protected)+n] - framedMsg, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit) + framedMsg, length, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit) if err != nil { return 0, err } @@ -221,24 +251,25 @@ func (p *conn) Read(b []byte) (n int, err error) { } // parseFramedMsg parses the provided buffer and returns a frame of the format -// msgLength+msg iff a full frame is available. -func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, error) { +// msgLength+msg iff a full frame is available. It also returns the message +// length if available. +func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, uint32, error) { // If the size field is not complete, return the provided buffer as // remaining buffer. p.nextFrame = b length, sufficientBytes := parseMessageLength(b) if !sufficientBytes { - return nil, nil + return nil, length, nil } if length > maxLen { - return nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) + return nil, length, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen) } - if len(b) < int(length)+4 { // account for the first 4 msg length bytes. + if len(b) < int(length)+MsgLenFieldSize { // account for the first 4 msg length bytes. // Frame is not complete yet. - return nil, nil + return nil, length, nil } p.nextFrame = b[MsgLenFieldSize+length:] - return b[:MsgLenFieldSize+length], nil + return b[:MsgLenFieldSize+length], length, nil } // parseMessageLength returns the message length based on frame header. It also