client: synchronously verify server preface in newClientTransport (#5731)

This commit is contained in:
Doug Fawley 2022-10-20 09:29:17 -07:00 committed by GitHub
parent f51d21267d
commit 9127159caf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 267 additions and 224 deletions

View File

@ -1228,38 +1228,33 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
// address was not successfully connected, or updates ac appropriately with the
// new transport.
func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
// TODO: Delete prefaceReceived and move the logic to wait for it into the
// transport.
prefaceReceived := grpcsync.NewEvent()
connClosed := grpcsync.NewEvent()
addr.ServerName = ac.cc.getServerName(addr)
hctx, hcancel := context.WithCancel(ac.ctx)
hcStarted := false // protected by ac.mu
onClose := func() {
onClose := grpcsync.OnceFunc(func() {
ac.mu.Lock()
defer ac.mu.Unlock()
defer connClosed.Fire()
defer hcancel()
if !hcStarted || hctx.Err() != nil {
// We didn't start the health check or set the state to READY, so
// no need to do anything else here.
//
// OR, we have already cancelled the health check context, meaning
// we have already called onClose once for this transport. In this
// case it would be dangerous to clear the transport and update the
// state, since there may be a new transport in this addrConn.
if ac.state == connectivity.Shutdown {
// Already shut down. tearDown() already cleared the transport and
// canceled hctx via ac.ctx, and we expected this connection to be
// closed, so do nothing here.
return
}
hcancel()
if ac.transport == nil {
// We're still connecting to this address, which could error. Do
// not update the connectivity state or resolve; these will happen
// at the end of the tryAllAddrs connection loop in the event of an
// error.
return
}
ac.transport = nil
// Refresh the name resolver
// Refresh the name resolver on any connection loss.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
}
// Always go idle and wait for the LB policy to initiate a new
// connection attempt.
ac.updateConnectivityState(connectivity.Idle, nil)
})
onGoAway := func(r transport.GoAwayReason) {
ac.mu.Lock()
ac.adjustParams(r)
@ -1271,7 +1266,7 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
defer cancel()
copts.ChannelzParentID = ac.channelzID
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, func() { prefaceReceived.Fire() }, onGoAway, onClose)
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onGoAway, onClose)
if err != nil {
// newTr is either nil, or closed.
hcancel()
@ -1279,66 +1274,34 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
return err
}
select {
case <-connectCtx.Done():
// We didn't get the preface in time.
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.state == connectivity.Shutdown {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// have been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
newTr.Close(transport.ErrConnClosing)
if connectCtx.Err() == context.DeadlineExceeded {
err := errors.New("failed to receive server preface within timeout")
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %s: %v", addr, err)
return err
}
go newTr.Close(transport.ErrConnClosing)
return nil
case <-prefaceReceived.Done():
// We got the preface - huzzah! things are good.
ac.mu.Lock()
defer ac.mu.Unlock()
if connClosed.HasFired() {
// onClose called first; go idle but do nothing else.
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
return nil
}
if ac.state == connectivity.Shutdown {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
go newTr.Close(transport.ErrConnClosing)
return nil
}
ac.curAddr = addr
ac.transport = newTr
hcStarted = true
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
return nil
case <-connClosed.Done():
// The transport has already closed. If we received the preface, too,
// this is not an error and go idle.
select {
case <-prefaceReceived.Done():
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
return nil
default:
return errors.New("connection closed before server preface received")
}
}
if hctx.Err() != nil {
// onClose was already called for this connection, but the connection
// was successfully established first. Consider it a success and set
// the new state to Idle.
ac.updateConnectivityState(connectivity.Idle, nil)
return nil
}
ac.curAddr = addr
ac.transport = newTr
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
return nil
}
// startHealthCheck starts the health checking stream (RPC) to watch the health

View File

@ -0,0 +1,32 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpcsync
import (
"sync"
)
// OnceFunc returns a function wrapping f which ensures f is only executed
// once even if the returned function is executed multiple times.
func OnceFunc(f func()) func() {
var once sync.Once
return func() {
once.Do(f)
}
}

View File

@ -0,0 +1,53 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package grpcsync
import (
"sync"
"sync/atomic"
"testing"
"time"
)
// TestOnceFunc tests that a OnceFunc is executed only once even with multiple
// simultaneous callers of it.
func (s) TestOnceFunc(t *testing.T) {
var v int32
inc := OnceFunc(func() { atomic.AddInt32(&v, 1) })
const numWorkers = 100
var wg sync.WaitGroup // Blocks until all workers have called inc.
wg.Add(numWorkers)
block := NewEvent() // Signal to worker goroutines to call inc
for i := 0; i < numWorkers; i++ {
go func() {
<-block.Done() // Wait for a signal.
inc() // Call the OnceFunc.
wg.Done()
}()
}
time.Sleep(time.Millisecond) // Allow goroutines to get to the block.
block.Fire() // Unblock them.
wg.Wait() // Wait for them to complete.
if v != 1 {
t.Fatalf("OnceFunc() called %v times; want 1", v)
}
}

View File

@ -38,6 +38,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/channelz"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
istatus "google.golang.org/grpc/internal/status"
@ -100,10 +101,6 @@ type http2Client struct {
maxSendHeaderListSize *uint32
bdpEst *bdpEstimator
// onPrefaceReceipt is a callback that client transport calls upon
// receiving server preface to signal that a succefull HTTP2
// connection was established.
onPrefaceReceipt func()
maxConcurrentStreams uint32
streamQuota int64
@ -196,7 +193,7 @@ func isTemporary(err error) bool {
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) {
func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (_ *http2Client, err error) {
scheme := "http"
ctx, cancel := context.WithCancel(ctx)
defer func() {
@ -218,12 +215,35 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
return nil, connectionErrorf(true, err, "transport: Error while dialing %v", err)
}
// Any further errors will close the underlying connection
defer func(conn net.Conn) {
if err != nil {
conn.Close()
}
}(conn)
// The following defer and goroutine monitor the connectCtx for cancelation
// and deadline. On context expiration, the connection is hard closed and
// this function will naturally fail as a result. Otherwise, the defer
// waits for the goroutine to exit to prevent the context from being
// monitored (and to prevent the connection from ever being closed) after
// returning from this function.
ctxMonitorDone := grpcsync.NewEvent()
newClientCtx, newClientDone := context.WithCancel(connectCtx)
defer func() {
newClientDone() // Awaken the goroutine below if connectCtx hasn't expired.
<-ctxMonitorDone.Done() // Wait for the goroutine below to exit.
}()
go func(conn net.Conn) {
defer ctxMonitorDone.Fire() // Signal this goroutine has exited.
<-newClientCtx.Done() // Block until connectCtx expires or the defer above executes.
if connectCtx.Err() != nil {
// connectCtx expired before exiting the function. Hard close the connection.
conn.Close()
}
}(conn)
kp := opts.KeepaliveParams
// Validate keepalive parameters.
if kp.Time == 0 {
@ -255,15 +275,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
}
}
if transportCreds != nil {
rawConn := conn
// Pull the deadline from the connectCtx, which will be used for
// timeouts in the authentication protocol handshake. Can ignore the
// boolean as the deadline will return the zero value, which will make
// the conn not timeout on I/O operations.
deadline, _ := connectCtx.Deadline()
rawConn.SetDeadline(deadline)
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, rawConn)
rawConn.SetDeadline(time.Time{})
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
@ -318,16 +330,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
kp: kp,
statsHandlers: opts.StatsHandlers,
initialWindowSize: initialWindowSize,
onPrefaceReceipt: onPrefaceReceipt,
nextID: 1,
maxConcurrentStreams: defaultMaxStreamsClient,
streamQuota: defaultMaxStreamsClient,
streamsQuotaAvailable: make(chan struct{}, 1),
czData: new(channelzData),
onGoAway: onGoAway,
onClose: onClose,
keepaliveEnabled: keepaliveEnabled,
bufferPool: newBufferPool(),
onClose: onClose,
}
// Add peer information to the http2client context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())
@ -366,21 +377,32 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
t.kpDormancyCond = sync.NewCond(&t.mu)
go t.keepalive()
}
// Start the reader goroutine for incoming message. Each transport has
// a dedicated goroutine which reads HTTP2 frame from network. Then it
// dispatches the frame to the corresponding stream entity.
go t.reader()
// Start the reader goroutine for incoming messages. Each transport has a
// dedicated goroutine which reads HTTP2 frames from the network. Then it
// dispatches the frame to the corresponding stream entity. When the
// server preface is received, readerErrCh is closed. If an error occurs
// first, an error is pushed to the channel. This must be checked before
// returning from this function.
readerErrCh := make(chan error, 1)
go t.reader(readerErrCh)
defer func() {
if err == nil {
err = <-readerErrCh
}
if err != nil {
t.Close(err)
}
}()
// Send connection preface to server.
n, err := t.conn.Write(clientPreface)
if err != nil {
err = connectionErrorf(true, err, "transport: failed to write client preface: %v", err)
t.Close(err)
return nil, err
}
if n != len(clientPreface) {
err = connectionErrorf(true, nil, "transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface))
t.Close(err)
return nil, err
}
var ss []http2.Setting
@ -400,14 +422,12 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
err = t.framer.fr.WriteSettings(ss...)
if err != nil {
err = connectionErrorf(true, err, "transport: failed to write initial settings frame: %v", err)
t.Close(err)
return nil, err
}
// Adjust the connection flow control window if needed.
if delta := uint32(icwz - defaultWindowSize); delta > 0 {
if err := t.framer.fr.WriteWindowUpdate(0, delta); err != nil {
err = connectionErrorf(true, err, "transport: failed to write window update: %v", err)
t.Close(err)
return nil, err
}
}
@ -907,19 +927,15 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// Close kicks off the shutdown process of the transport. This should be called
// only once on a transport. Once it is called, the transport should not be
// accessed any more.
//
// This method blocks until the addrConn that initiated this transport is
// re-connected. This happens because t.onClose() begins reconnect logic at the
// addrConn level and blocks until the addrConn is successfully connected.
func (t *http2Client) Close(err error) {
t.mu.Lock()
// Make sure we only Close once.
// Make sure we only close once.
if t.state == closing {
t.mu.Unlock()
return
}
// Call t.onClose before setting the state to closing to prevent the client
// from attempting to create new streams ASAP.
// Call t.onClose ASAP to prevent the client from attempting to create new
// streams.
t.onClose()
t.state = closing
streams := t.activeStreams
@ -1509,33 +1525,35 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true)
}
// reader runs as a separate goroutine in charge of reading data from network
// connection.
//
// TODO(zhaoq): currently one reader per transport. Investigate whether this is
// optimal.
// TODO(zhaoq): Check the validity of the incoming frame sequence.
func (t *http2Client) reader() {
defer close(t.readerDone)
// Check the validity of server preface.
// readServerPreface reads and handles the initial settings frame from the
// server.
func (t *http2Client) readServerPreface() error {
frame, err := t.framer.fr.ReadFrame()
if err != nil {
err = connectionErrorf(true, err, "error reading server preface: %v", err)
t.Close(err) // this kicks off resetTransport, so must be last before return
return
}
t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!)
if t.keepaliveEnabled {
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
return connectionErrorf(true, err, "error reading server preface: %v", err)
}
sf, ok := frame.(*http2.SettingsFrame)
if !ok {
// this kicks off resetTransport, so must be last before return
t.Close(connectionErrorf(true, nil, "initial http2 frame from server is not a settings frame: %T", frame))
return connectionErrorf(true, nil, "initial http2 frame from server is not a settings frame: %T", frame)
}
t.handleSettings(sf, true)
return nil
}
// reader verifies the server preface and reads all subsequent data from
// network connection. If the server preface is not read successfully, an
// error is pushed to errCh; otherwise errCh is closed with no error.
func (t *http2Client) reader(errCh chan<- error) {
defer close(t.readerDone)
if err := t.readServerPreface(); err != nil {
errCh <- err
return
}
t.onPrefaceReceipt()
t.handleSettings(sf, true)
close(errCh)
if t.keepaliveEnabled {
atomic.StoreInt64(&t.lastRead, time.Now().UnixNano())
}
// loop to keep reading incoming messages on this transport.
for {

View File

@ -573,8 +573,8 @@ type ConnectOptions struct {
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onPrefaceReceipt func(), onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, addr, opts, onPrefaceReceipt, onGoAway, onClose)
func NewClientTransport(connectCtx, ctx context.Context, addr resolver.Address, opts ConnectOptions, onGoAway func(GoAwayReason), onClose func()) (ClientTransport, error) {
return newHTTP2Client(connectCtx, ctx, addr, opts, onGoAway, onClose)
}
// Options provides additional hints and information for message

View File

@ -452,7 +452,7 @@ func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts
copts.ChannelzParentID = channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func() {}, func(GoAwayReason) {}, func() {})
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
if connErr != nil {
cancel() // Do not cancel in success path.
t.Fatalf("failed to create transport: %v", connErr)
@ -474,10 +474,16 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, connCh chan net.C
close(connCh)
return
}
framer := http2.NewFramer(conn, conn)
if err := framer.WriteSettings(); err != nil {
t.Errorf("Error at server-side while writing settings: %v", err)
close(connCh)
return
}
connCh <- conn
}()
connectCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(2*time.Second))
tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func() {}, func(GoAwayReason) {}, func() {})
tr, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
if err != nil {
cancel() // Do not cancel in success path.
// Server clean-up.
@ -1248,6 +1254,59 @@ func (s) TestServerWithMisbehavedClient(t *testing.T) {
}
}
func (s) TestClientHonorsConnectContext(t *testing.T) {
// Create a server that will not send a preface.
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Error while listening: %v", err)
}
defer lis.Close()
go func() { // Launch the misbehaving server.
sconn, err := lis.Accept()
if err != nil {
t.Errorf("Error while accepting: %v", err)
return
}
defer sconn.Close()
if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
t.Errorf("Error while reading client preface: %v", err)
return
}
sfr := http2.NewFramer(sconn, sconn)
// Do not write a settings frame, but read from the conn forever.
for {
if _, err := sfr.ReadFrame(); err != nil {
return
}
}
}()
// Test context cancelation.
timeBefore := time.Now()
connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
time.AfterFunc(100*time.Millisecond, cancel)
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
if err == nil {
t.Fatalf("NewClientTransport() returned successfully; wanted error")
}
t.Logf("NewClientTransport() = _, %v", err)
if time.Now().Sub(timeBefore) > 3*time.Second {
t.Fatalf("NewClientTransport returned > 2.9s after context cancelation")
}
// Test context deadline.
timeBefore = time.Now()
connectCtx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err = NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
if err == nil {
t.Fatalf("NewClientTransport() returned successfully; wanted error")
}
t.Logf("NewClientTransport() = _, %v", err)
}
func (s) TestClientWithMisbehavedServer(t *testing.T) {
// Create a misbehaving server.
lis, err := net.Listen("tcp", "localhost:0")
@ -1266,10 +1325,14 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
}
defer sconn.Close()
if _, err := io.ReadFull(sconn, make([]byte, len(clientPreface))); err != nil {
t.Errorf("Error while reading clieng preface: %v", err)
t.Errorf("Error while reading client preface: %v", err)
return
}
sfr := http2.NewFramer(sconn, sconn)
if err := sfr.WriteSettings(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
}
if err := sfr.WriteSettingsAck(); err != nil {
t.Errorf("Error while writing settings: %v", err)
return
@ -1316,7 +1379,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
defer cancel()
copts := ConnectOptions{ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil)}
ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func() {}, func(GoAwayReason) {}, func() {})
ct, err := NewClientTransport(connectCtx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {}, func() {})
if err != nil {
t.Fatalf("Error while creating client transport: %v", err)
}
@ -2217,7 +2280,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
TransportCredentials: creds,
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
}
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func() {}, func(GoAwayReason) {}, func() {})
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
if err != nil {
t.Fatalf("NewClientTransport(): %v", err)
}
@ -2258,7 +2321,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
Dialer: dialer,
ChannelzParentID: channelz.NewIdentifierForTesting(channelz.RefSubChannel, time.Now().Unix(), nil),
}
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func() {}, func(GoAwayReason) {}, func() {})
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}, func() {})
if err != nil {
t.Fatalf("NewClientTransport(): %v", err)
}

View File

@ -7798,92 +7798,6 @@ func (s) TestClientSettingsFloodCloseConn(t *testing.T) {
timer.Stop()
}
// TestDeadlineSetOnConnectionOnClientCredentialHandshake tests that there is a deadline
// set on the net.Conn when a credential handshake happens in http2_client.
func (s) TestDeadlineSetOnConnectionOnClientCredentialHandshake(t *testing.T) {
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen: %v", err)
}
connCh := make(chan net.Conn, 1)
go func() {
defer close(connCh)
conn, err := lis.Accept()
if err != nil {
t.Errorf("Error accepting connection: %v", err)
return
}
connCh <- conn
}()
defer func() {
conn := <-connCh
if conn != nil {
conn.Close()
}
}()
deadlineCh := testutils.NewChannel()
cvd := &credentialsVerifyDeadline{
deadlineCh: deadlineCh,
}
dOpt := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
return &infoConn{Conn: conn}, nil
})
cc, err := grpc.Dial(lis.Addr().String(), dOpt, grpc.WithTransportCredentials(cvd))
if err != nil {
t.Fatalf("Failed to dial: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
deadline, err := deadlineCh.Receive(ctx)
if err != nil {
t.Fatalf("Error receiving from credsInvoked: %v", err)
}
// Default connection timeout is 20 seconds, so if the deadline exceeds now
// + 18 seconds it should be valid.
if !deadline.(time.Time).After(time.Now().Add(time.Second * 18)) {
t.Fatalf("Connection did not have deadline set.")
}
}
type infoConn struct {
net.Conn
deadline time.Time
}
func (c *infoConn) SetDeadline(t time.Time) error {
c.deadline = t
return c.Conn.SetDeadline(t)
}
type credentialsVerifyDeadline struct {
deadlineCh *testutils.Channel
}
func (cvd *credentialsVerifyDeadline) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
return rawConn, nil, nil
}
func (cvd *credentialsVerifyDeadline) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
cvd.deadlineCh.Send(rawConn.(*infoConn).deadline)
return rawConn, nil, nil
}
func (cvd *credentialsVerifyDeadline) Info() credentials.ProtocolInfo {
return credentials.ProtocolInfo{}
}
func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials {
return cvd
}
func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error {
return nil
}
func unaryInterceptorVerifyConn(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
conn := transport.GetConnection(ctx)
if conn == nil {