alts: receive low watermark support

The implementation of setRcvlowat is based on the gRCP C++ library
implementation.

Part of #8510.
This commit is contained in:
Kevin Krakauer 2025-08-05 20:56:28 -07:00
parent 09c22f854f
commit fdd3cf5e27
3 changed files with 145 additions and 9 deletions

View File

@ -0,0 +1,26 @@
//go:build !linux
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
// SO_RCVLOWAT exists on non-Linux OSes, but we have't tested them.
func (p *conn) setRcvlowat(length int) error {
return nil
}

View File

@ -0,0 +1,79 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package conn
import (
"errors"
"golang.org/x/sys/unix"
)
// setRcvlowat updates SO_RCVLOWAT to reduce CPU usage.
func (p *conn) setRcvlowat(length int) error {
if p.rawConn == nil {
return nil
}
const (
rcvlowatMax = 16 * 1024 * 1024
rcvlowatMin = 32 * 1024
rcvlowatGap = 16 * 1024
)
remaining := min(cap(p.protected), length, rcvlowatMax)
// Small SO_RCVLOWAT values don't actually save CPU.
if remaining < rcvlowatMin {
remaining = 0
}
// Allow for a small gap, which can wake us up a tiny bit early. This
// helps with latency, as bytes can arrive between wakeup and the
// ensuing read.
if remaining > 0 {
remaining -= rcvlowatGap
}
// Don't hold up the socket once we've hit our threshold.
if len(p.protected) > remaining {
remaining = 0
}
// Don't enable SO_RCVLOWAT if it's not useful.
if p.rcvlowat <= 1 && remaining <= 1 {
return nil
}
// Don't make a syscall if nothing changed.
if p.rcvlowat == remaining {
return nil
}
// Make the actual setsockopt call.
var sockoptErr error
err := p.rawConn.Control(func(fd uintptr) {
sockoptErr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVLOWAT, p.rcvlowat)
})
if err != nil || sockoptErr != nil {
return errors.Join(err, sockoptErr)
}
p.rcvlowat = remaining
return nil
}

View File

@ -25,6 +25,7 @@ import (
"fmt"
"math"
"net"
"syscall"
core "google.golang.org/grpc/credentials/alts/internal"
)
@ -97,6 +98,18 @@ type conn struct {
nextFrame []byte
// overhead is the calculated overhead of each frame.
overhead int
// rcvlowat is the "receive low watermark" used to avoid unnecessary
// early returns from the kernel during [conn.Read], which saves CPU and
// can boost throughput under load. When we receive the first few bytes
// of a message we examine the length field. If, for example, we know
// there's 512KB of data remaining in the record, rcvlowat tells the
// kernel "don't wake me up every time you get another packet; wait
// until you have all 512KB."
//
// See SO_RCVLOWAT in tcp(7) for more info.
rcvlowat int
// rawConn allows us to set SO_RCVLOWAT on the underlying TCP socket.
rawConn syscall.RawConn
}
// NewConn creates a new secure channel instance given the other party role and
@ -129,6 +142,18 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
nextFrame: protectedBuf,
overhead: overhead,
}
if rcvlowat {
tcpConn, ok := c.(*net.TCPConn)
if !ok {
return nil, fmt.Errorf("rcvlowat requires a *net.TCPConn, but got %T", c)
}
if altsConn.rawConn, err = tcpConn.SyscallConn(); err != nil {
return nil, fmt.Errorf("failed to get raw connection: %w", err)
}
altsConn.rcvlowat = 1
}
return altsConn, nil
}
@ -139,7 +164,8 @@ func NewConn(c net.Conn, side core.Side, recordProtocol string, key []byte, prot
func (p *conn) Read(b []byte) (n int, err error) {
if len(p.buf) == 0 {
var framedMsg []byte
framedMsg, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit)
var length uint32
framedMsg, length, err = p.parseFramedMsg(p.nextFrame, altsRecordLengthLimit)
if err != nil {
return n, err
}
@ -154,6 +180,10 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
// Check whether a complete frame has been received yet.
for len(framedMsg) == 0 {
if err := p.setRcvlowat(int(length)); err != nil {
return 0, err
}
if len(p.protected) == cap(p.protected) {
// We can parse the length header to know exactly how large
// the buffer needs to be to hold the entire frame.
@ -179,7 +209,7 @@ func (p *conn) Read(b []byte) (n int, err error) {
return 0, err
}
p.protected = p.protected[:len(p.protected)+n]
framedMsg, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit)
framedMsg, length, err = p.parseFramedMsg(p.protected, altsRecordLengthLimit)
if err != nil {
return 0, err
}
@ -221,24 +251,25 @@ func (p *conn) Read(b []byte) (n int, err error) {
}
// parseFramedMsg parses the provided buffer and returns a frame of the format
// msgLength+msg iff a full frame is available.
func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, error) {
// msgLength+msg iff a full frame is available. It also returns the message
// length if available.
func (p *conn) parseFramedMsg(b []byte, maxLen uint32) ([]byte, uint32, error) {
// If the size field is not complete, return the provided buffer as
// remaining buffer.
p.nextFrame = b
length, sufficientBytes := parseMessageLength(b)
if !sufficientBytes {
return nil, nil
return nil, length, nil
}
if length > maxLen {
return nil, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
return nil, length, fmt.Errorf("received the frame length %d larger than the limit %d", length, maxLen)
}
if len(b) < int(length)+4 { // account for the first 4 msg length bytes.
if len(b) < int(length)+MsgLenFieldSize { // account for the first 4 msg length bytes.
// Frame is not complete yet.
return nil, nil
return nil, length, nil
}
p.nextFrame = b[MsgLenFieldSize+length:]
return b[:MsgLenFieldSize+length], nil
return b[:MsgLenFieldSize+length], length, nil
}
// parseMessageLength returns the message length based on frame header. It also