diff --git a/clientconn.go b/clientconn.go index 21541b630..61ac1a0df 100644 --- a/clientconn.go +++ b/clientconn.go @@ -431,13 +431,16 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * if cc.dopts.bs == nil { cc.dopts.bs = DefaultBackoffConfig } + cc.parsedTarget = parseTarget(cc.target) creds := cc.dopts.copts.TransportCredentials if creds != nil && creds.Info().ServerName != "" { cc.authority = creds.Info().ServerName } else if cc.dopts.insecure && cc.dopts.copts.Authority != "" { cc.authority = cc.dopts.copts.Authority } else { - cc.authority = target + // Use endpoint from "scheme://authority/endpoint" as the default + // authority for ClientConn. + cc.authority = cc.parsedTarget.Endpoint } if cc.dopts.scChan != nil && !scSet { @@ -541,10 +544,11 @@ type ClientConn struct { ctx context.Context cancel context.CancelFunc - target string - authority string - dopts dialOptions - csMgr *connectivityStateManager + target string + parsedTarget resolver.Target + authority string + dopts dialOptions + csMgr *connectivityStateManager customBalancer bool // If this is true, switching balancer will be disabled. balancerBuildOpts balancer.BuildOptions @@ -953,8 +957,9 @@ func (ac *addrConn) resetTransport() error { } ac.mu.Unlock() sinfo := transport.TargetInfo{ - Addr: addr.Addr, - Metadata: addr.Metadata, + Addr: addr.Addr, + Metadata: addr.Metadata, + Authority: ac.cc.authority, } newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout) if err != nil { diff --git a/credentials/credentials.go b/credentials/credentials.go index c6575b5c7..90b6a6117 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -127,15 +127,15 @@ func (c tlsCreds) Info() ProtocolInfo { } } -func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { +func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := cloneTLSConfig(c.config) if cfg.ServerName == "" { - colonPos := strings.LastIndex(addr, ":") + colonPos := strings.LastIndex(authority, ":") if colonPos == -1 { - colonPos = len(addr) + colonPos = len(authority) } - cfg.ServerName = addr[:colonPos] + cfg.ServerName = authority[:colonPos] } conn := tls.Client(rawConn, cfg) errChannel := make(chan error, 1) diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go index 7bb843efd..53fb77a6c 100644 --- a/resolver_conn_wrapper.go +++ b/resolver_conn_wrapper.go @@ -58,12 +58,11 @@ func parseTarget(target string) (ret resolver.Target) { // builder for this scheme. It then builds the resolver and starts the // monitoring goroutine for it. func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { - target := parseTarget(cc.target) - grpclog.Infof("dialing to target with scheme: %q", target.Scheme) + grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme) - rb := resolver.Get(target.Scheme) + rb := resolver.Get(cc.parsedTarget.Scheme) if rb == nil { - return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme) + return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme) } ccr := &ccResolverWrapper{ @@ -74,7 +73,7 @@ func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) { } var err error - ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{}) + ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{}) if err != nil { return nil, err } diff --git a/test/end2end_test.go b/test/end2end_test.go index 6f0c0fa16..be9eb2755 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4287,6 +4287,69 @@ func TestServerCredsDispatch(t *testing.T) { } } +type authorityCheckCreds struct { + got string +} + +func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} +func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + c.got = authority + return rawConn, nil, nil +} +func (c *authorityCheckCreds) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (c *authorityCheckCreds) Clone() credentials.TransportCredentials { + return c +} +func (c *authorityCheckCreds) OverrideServerName(s string) error { + return nil +} + +// This test makes sure that the authority client handshake gets is the endpoint +// in dial target, not the resolved ip address. +func TestCredsHandshakeAuthority(t *testing.T) { + const testAuthority = "test.auth.ori.ty" + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + cred := &authorityCheckCreds{} + s := grpc.NewServer() + go s.Serve(lis) + defer s.Stop() + + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + // ctx got timeout or canceled. + t.Fatalf("ClientConn is not ready after 100 ms") + } + } + + if cred.got != testAuthority { + t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) + } +} + func TestFlowControlLogicalRace(t *testing.T) { // Test for a regression of https://github.com/grpc/grpc-go/issues/632, // and other flow control bugs. diff --git a/transport/http2_client.go b/transport/http2_client.go index 5985666bd..f665ef047 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -44,7 +44,6 @@ import ( type http2Client struct { ctx context.Context cancel context.CancelFunc - target string // server name/addr userAgent string md interface{} conn net.Conn // underlying communication channel @@ -175,7 +174,7 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t ) if creds := opts.TransportCredentials; creds != nil { scheme = "https" - conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn) + conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn) if err != nil { // Credentials handshake errors are typically considered permanent // to avoid retrying on e.g. bad certificates. @@ -210,7 +209,6 @@ func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions, t t := &http2Client{ ctx: ctx, cancel: cancel, - target: addr.Addr, userAgent: opts.UserAgent, md: addr.Metadata, conn: conn, diff --git a/transport/transport.go b/transport/transport.go index 44d48a99a..e0651c4d2 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -513,8 +513,9 @@ type ConnectOptions struct { // TargetInfo contains the information of the target such as network address and metadata. type TargetInfo struct { - Addr string - Metadata interface{} + Addr string + Metadata interface{} + Authority string } // NewClientTransport establishes the transport with the required ConnectOptions