transport/bufWriter: fast-fail on error returned from flushKeepBuffer() (#7394)

This commit is contained in:
Oleg Guba 2024-08-07 12:07:18 -07:00 committed by GitHub
parent 1490d60f47
commit ffaa81e286
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 9 deletions

View File

@ -317,28 +317,32 @@ func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter {
return w
}
func (w *bufWriter) Write(b []byte) (n int, err error) {
func (w *bufWriter) Write(b []byte) (int, error) {
if w.err != nil {
return 0, w.err
}
if w.batchSize == 0 { // Buffer has been disabled.
n, err = w.conn.Write(b)
n, err := w.conn.Write(b)
return n, toIOError(err)
}
if w.buf == nil {
b := w.pool.Get().(*[]byte)
w.buf = *b
}
written := 0
for len(b) > 0 {
nn := copy(w.buf[w.offset:], b)
b = b[nn:]
w.offset += nn
n += nn
if w.offset >= w.batchSize {
err = w.flushKeepBuffer()
copied := copy(w.buf[w.offset:], b)
b = b[copied:]
written += copied
w.offset += copied
if w.offset < w.batchSize {
continue
}
if err := w.flushKeepBuffer(); err != nil {
return written, err
}
}
return n, err
return written, nil
}
func (w *bufWriter) Flush() error {

View File

@ -19,7 +19,10 @@
package transport
import (
"errors"
"fmt"
"io"
"net"
"reflect"
"testing"
"time"
@ -215,6 +218,39 @@ func (s) TestParseDialTarget(t *testing.T) {
}
}
type badNetworkConn struct {
net.Conn
}
func (c *badNetworkConn) Write([]byte) (int, error) {
return 0, io.EOF
}
// This test ensures Write() on a broken network connection does not lead to
// an infinite loop. See https://github.com/grpc/grpc-go/issues/7389 for more details.
func (s) TestWriteBadConnection(t *testing.T) {
data := []byte("test_data")
// Configure the bufWriter with a batchsize that results in data being flushed
// to the underlying conn, midway through Write().
writeBufferSize := (len(data) - 1) / 2
writer := newBufWriter(&badNetworkConn{}, writeBufferSize, getWriteBufferPool(writeBufferSize))
errCh := make(chan error, 1)
go func() {
_, err := writer.Write(data)
errCh <- err
}()
select {
case <-time.After(time.Second):
t.Fatalf("Write() did not return in time")
case err := <-errCh:
if !errors.Is(err, io.EOF) {
t.Fatalf("Write() = %v, want error presence = %v", err, io.EOF)
}
}
}
func BenchmarkDecodeGrpcMessage(b *testing.B) {
input := "Hello, %E4%B8%96%E7%95%8C"
want := "Hello, 世界"