mirror of https://github.com/grpc/grpc-go.git
				
				
				
			credentials/tls: reject connections with ALPN disabled (#7184)
This commit is contained in:
		
							parent
							
								
									0a0abfadb7
								
							
						
					
					
						commit
						48b6b11b38
					
				|  | @ -27,9 +27,13 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| 
 | 
 | ||||||
|  | 	"google.golang.org/grpc/grpclog" | ||||||
| 	credinternal "google.golang.org/grpc/internal/credentials" | 	credinternal "google.golang.org/grpc/internal/credentials" | ||||||
|  | 	"google.golang.org/grpc/internal/envconfig" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var logger = grpclog.Component("credentials") | ||||||
|  | 
 | ||||||
| // TLSInfo contains the auth information for a TLS authenticated connection.
 | // TLSInfo contains the auth information for a TLS authenticated connection.
 | ||||||
| // It implements the AuthInfo interface.
 | // It implements the AuthInfo interface.
 | ||||||
| type TLSInfo struct { | type TLSInfo struct { | ||||||
|  | @ -112,6 +116,22 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon | ||||||
| 		conn.Close() | 		conn.Close() | ||||||
| 		return nil, nil, ctx.Err() | 		return nil, nil, ctx.Err() | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	// The negotiated protocol can be either of the following:
 | ||||||
|  | 	// 1. h2: When the server supports ALPN. Only HTTP/2 can be negotiated since
 | ||||||
|  | 	//    it is the only protocol advertised by the client during the handshake.
 | ||||||
|  | 	//    The tls library ensures that the server chooses a protocol advertised
 | ||||||
|  | 	//    by the client.
 | ||||||
|  | 	// 2. "" (empty string): If the server doesn't support ALPN. ALPN is a requirement
 | ||||||
|  | 	//    for using HTTP/2 over TLS. We can terminate the connection immediately.
 | ||||||
|  | 	np := conn.ConnectionState().NegotiatedProtocol | ||||||
|  | 	if np == "" { | ||||||
|  | 		if envconfig.EnforceALPNEnabled { | ||||||
|  | 			conn.Close() | ||||||
|  | 			return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property") | ||||||
|  | 		} | ||||||
|  | 		logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName) | ||||||
|  | 	} | ||||||
| 	tlsInfo := TLSInfo{ | 	tlsInfo := TLSInfo{ | ||||||
| 		State: conn.ConnectionState(), | 		State: conn.ConnectionState(), | ||||||
| 		CommonAuthInfo: CommonAuthInfo{ | 		CommonAuthInfo: CommonAuthInfo{ | ||||||
|  | @ -131,8 +151,20 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) | ||||||
| 		conn.Close() | 		conn.Close() | ||||||
| 		return nil, nil, err | 		return nil, nil, err | ||||||
| 	} | 	} | ||||||
|  | 	cs := conn.ConnectionState() | ||||||
|  | 	// The negotiated application protocol can be empty only if the client doesn't
 | ||||||
|  | 	// support ALPN. In such cases, we can close the connection since ALPN is required
 | ||||||
|  | 	// for using HTTP/2 over TLS.
 | ||||||
|  | 	if cs.NegotiatedProtocol == "" { | ||||||
|  | 		if envconfig.EnforceALPNEnabled { | ||||||
|  | 			conn.Close() | ||||||
|  | 			return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property") | ||||||
|  | 		} else if logger.V(2) { | ||||||
|  | 			logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	tlsInfo := TLSInfo{ | 	tlsInfo := TLSInfo{ | ||||||
| 		State: conn.ConnectionState(), | 		State: cs, | ||||||
| 		CommonAuthInfo: CommonAuthInfo{ | 		CommonAuthInfo: CommonAuthInfo{ | ||||||
| 			SecurityLevel: PrivacyAndIntegrity, | 			SecurityLevel: PrivacyAndIntegrity, | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | @ -23,6 +23,7 @@ import ( | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | @ -31,6 +32,7 @@ import ( | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| 	"google.golang.org/grpc/codes" | 	"google.golang.org/grpc/codes" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
|  | 	"google.golang.org/grpc/internal/envconfig" | ||||||
| 	"google.golang.org/grpc/internal/grpctest" | 	"google.golang.org/grpc/internal/grpctest" | ||||||
| 	"google.golang.org/grpc/internal/stubserver" | 	"google.golang.org/grpc/internal/stubserver" | ||||||
| 	"google.golang.org/grpc/status" | 	"google.golang.org/grpc/status" | ||||||
|  | @ -236,3 +238,160 @@ func (s) TestTLS_CipherSuitesOverridable(t *testing.T) { | ||||||
| 		t.Fatalf("EmptyCall err = %v; want <nil>", err) | 		t.Fatalf("EmptyCall err = %v; want <nil>", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
 | ||||||
|  | // connecting to a server that doesn't support ALPN.
 | ||||||
|  | func (s) TestTLS_DisabledALPNClient(t *testing.T) { | ||||||
|  | 	initialVal := envconfig.EnforceALPNEnabled | ||||||
|  | 	defer func() { | ||||||
|  | 		envconfig.EnforceALPNEnabled = initialVal | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name         string | ||||||
|  | 		alpnEnforced bool | ||||||
|  | 		wantErr      bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:         "enforced", | ||||||
|  | 			alpnEnforced: true, | ||||||
|  | 			wantErr:      true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "not_enforced", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			envconfig.EnforceALPNEnabled = tc.alpnEnforced | ||||||
|  | 
 | ||||||
|  | 			listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{ | ||||||
|  | 				Certificates: []tls.Certificate{serverCert}, | ||||||
|  | 				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
 | ||||||
|  | 			}) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("Error starting TLS server: %v", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			errCh := make(chan error, 1) | ||||||
|  | 			go func() { | ||||||
|  | 				conn, err := listener.Accept() | ||||||
|  | 				if err != nil { | ||||||
|  | 					errCh <- fmt.Errorf("listener.Accept returned error: %v", err) | ||||||
|  | 				} else { | ||||||
|  | 					// The first write to the TLS listener initiates the TLS handshake.
 | ||||||
|  | 					conn.Write([]byte("Hello, World!")) | ||||||
|  | 					conn.Close() | ||||||
|  | 				} | ||||||
|  | 				close(errCh) | ||||||
|  | 			}() | ||||||
|  | 
 | ||||||
|  | 			serverAddr := listener.Addr().String() | ||||||
|  | 			conn, err := net.Dial("tcp", serverAddr) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err) | ||||||
|  | 			} | ||||||
|  | 			defer conn.Close() | ||||||
|  | 
 | ||||||
|  | 			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||||||
|  | 			defer cancel() | ||||||
|  | 
 | ||||||
|  | 			clientCfg := tls.Config{ | ||||||
|  | 				ServerName: serverName, | ||||||
|  | 				RootCAs:    certPool, | ||||||
|  | 				NextProtos: []string{"h2"}, | ||||||
|  | 			} | ||||||
|  | 			_, _, err = credentials.NewTLS(&clientCfg).ClientHandshake(ctx, serverName, conn) | ||||||
|  | 
 | ||||||
|  | 			if gotErr := (err != nil); gotErr != tc.wantErr { | ||||||
|  | 				t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			select { | ||||||
|  | 			case err := <-errCh: | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("Unexpected error received from server: %v", err) | ||||||
|  | 				} | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				t.Fatalf("Timeout waiting for error from server") | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
 | ||||||
|  | // accepting a request from a client that doesn't support ALPN.
 | ||||||
|  | func (s) TestTLS_DisabledALPNServer(t *testing.T) { | ||||||
|  | 	initialVal := envconfig.EnforceALPNEnabled | ||||||
|  | 	defer func() { | ||||||
|  | 		envconfig.EnforceALPNEnabled = initialVal | ||||||
|  | 	}() | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name         string | ||||||
|  | 		alpnEnforced bool | ||||||
|  | 		wantErr      bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:         "enforced", | ||||||
|  | 			alpnEnforced: true, | ||||||
|  | 			wantErr:      true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "not_enforced", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			envconfig.EnforceALPNEnabled = tc.alpnEnforced | ||||||
|  | 
 | ||||||
|  | 			listener, err := net.Listen("tcp", "localhost:0") | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("Error starting server: %v", err) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			errCh := make(chan error, 1) | ||||||
|  | 			go func() { | ||||||
|  | 				conn, err := listener.Accept() | ||||||
|  | 				if err != nil { | ||||||
|  | 					errCh <- fmt.Errorf("listener.Accept returned error: %v", err) | ||||||
|  | 					return | ||||||
|  | 				} | ||||||
|  | 				defer conn.Close() | ||||||
|  | 				serverCfg := tls.Config{ | ||||||
|  | 					Certificates: []tls.Certificate{serverCert}, | ||||||
|  | 					NextProtos:   []string{"h2"}, | ||||||
|  | 				} | ||||||
|  | 				_, _, err = credentials.NewTLS(&serverCfg).ServerHandshake(conn) | ||||||
|  | 				if gotErr := (err != nil); gotErr != tc.wantErr { | ||||||
|  | 					t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr) | ||||||
|  | 				} | ||||||
|  | 				close(errCh) | ||||||
|  | 			}() | ||||||
|  | 
 | ||||||
|  | 			serverAddr := listener.Addr().String() | ||||||
|  | 			clientCfg := &tls.Config{ | ||||||
|  | 				Certificates: []tls.Certificate{serverCert}, | ||||||
|  | 				NextProtos:   []string{}, // Empty list indicates ALPN is disabled.
 | ||||||
|  | 				RootCAs:      certPool, | ||||||
|  | 				ServerName:   serverName, | ||||||
|  | 			} | ||||||
|  | 			conn, err := tls.Dial("tcp", serverAddr, clientCfg) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err) | ||||||
|  | 			} | ||||||
|  | 			defer conn.Close() | ||||||
|  | 
 | ||||||
|  | 			select { | ||||||
|  | 			case <-time.After(defaultTestTimeout): | ||||||
|  | 				t.Fatal("Timed out waiting for completion") | ||||||
|  | 			case err := <-errCh: | ||||||
|  | 				if err != nil { | ||||||
|  | 					t.Fatalf("Unexpected server error: %v", err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -40,6 +40,12 @@ var ( | ||||||
| 	// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
 | 	// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
 | ||||||
| 	// handshakes that can be performed.
 | 	// handshakes that can be performed.
 | ||||||
| 	ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) | 	ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) | ||||||
|  | 	// EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled
 | ||||||
|  | 	// should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this
 | ||||||
|  | 	// option is present for backward compatibility. This option may be overridden
 | ||||||
|  | 	// by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true"
 | ||||||
|  | 	// or "false".
 | ||||||
|  | 	EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", false) | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func boolFromEnv(envVar string, def bool) bool { | func boolFromEnv(envVar string, def bool) bool { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue