diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 119f01e3e..5b2493130 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -241,7 +241,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // and passed to the credential handshaker. This makes it possible for // address specific arbitrary data to reach the credential handshaker. connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) + rawConn := conn + // Pull the deadline from the connectCtx, which will be used for + // timeouts in the authentication protocol handshake. Can ignore the + // boolean as the deadline will return the zero value, which will make + // the conn not timeout on I/O operations. + deadline, _ := connectCtx.Deadline() + rawConn.SetDeadline(deadline) + conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, rawConn) + rawConn.SetDeadline(time.Time{}) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 552f74e1b..1b839529c 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -7668,3 +7668,89 @@ func (s) TestClientSettingsFloodCloseConn(t *testing.T) { s.GracefulStop() timer.Stop() } + +// TestDeadlineSetOnConnectionOnClientCredentialHandshake tests that there is a deadline +// set on the net.Conn when a credential handshake happens in http2_client. +func (s) TestDeadlineSetOnConnectionOnClientCredentialHandshake(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + connCh := make(chan net.Conn, 1) + go func() { + defer close(connCh) + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + defer func() { + conn := <-connCh + if conn != nil { + conn.Close() + } + }() + deadlineCh := testutils.NewChannel() + cvd := &credentialsVerifyDeadline{ + deadlineCh: deadlineCh, + } + dOpt := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return &infoConn{Conn: conn}, nil + }) + cc, err := grpc.Dial(lis.Addr().String(), dOpt, grpc.WithTransportCredentials(cvd)) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + deadline, err := deadlineCh.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from credsInvoked: %v", err) + } + // Default connection timeout is 20 seconds, so if the deadline exceeds now + // + 18 seconds it should be valid. + if !deadline.(time.Time).After(time.Now().Add(time.Second * 18)) { + t.Fatalf("Connection did not have deadline set.") + } +} + +type infoConn struct { + net.Conn + deadline time.Time +} + +func (c *infoConn) SetDeadline(t time.Time) error { + c.deadline = t + return c.Conn.SetDeadline(t) +} + +type credentialsVerifyDeadline struct { + deadlineCh *testutils.Channel +} + +func (cvd *credentialsVerifyDeadline) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + cvd.deadlineCh.Send(rawConn.(*infoConn).deadline) + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials { + return cvd +} +func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error { + return nil +}