mirror of https://github.com/grpc/grpc-go.git
credentials: Add experimental credentials that don't enforce ALPN (#7980)
This commit is contained in:
parent
130c1d73d0
commit
eb1added1d
|
@ -32,6 +32,8 @@ import (
|
||||||
"google.golang.org/grpc/internal/envconfig"
|
"google.golang.org/grpc/internal/envconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const alpnFailureHelpMessage = "If you upgraded from a grpc-go version earlier than 1.67, your TLS connections may have stopped working due to ALPN enforcement. For more details, see: https://github.com/grpc/grpc-go/issues/434"
|
||||||
|
|
||||||
var logger = grpclog.Component("credentials")
|
var logger = grpclog.Component("credentials")
|
||||||
|
|
||||||
// TLSInfo contains the auth information for a TLS authenticated connection.
|
// TLSInfo contains the auth information for a TLS authenticated connection.
|
||||||
|
@ -128,7 +130,7 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon
|
||||||
if np == "" {
|
if np == "" {
|
||||||
if envconfig.EnforceALPNEnabled {
|
if envconfig.EnforceALPNEnabled {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
|
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
|
||||||
}
|
}
|
||||||
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
|
logger.Warningf("Allowing TLS connection to server %q with ALPN disabled. TLS connections to servers with ALPN disabled will be disallowed in future grpc-go releases", cfg.ServerName)
|
||||||
}
|
}
|
||||||
|
@ -158,7 +160,7 @@ func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error)
|
||||||
if cs.NegotiatedProtocol == "" {
|
if cs.NegotiatedProtocol == "" {
|
||||||
if envconfig.EnforceALPNEnabled {
|
if envconfig.EnforceALPNEnabled {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property")
|
return nil, nil, fmt.Errorf("credentials: cannot check peer: missing selected ALPN property. %s", alpnFailureHelpMessage)
|
||||||
} else if logger.V(2) {
|
} else if logger.V(2) {
|
||||||
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
|
logger.Info("Allowing TLS connection from client with ALPN disabled. TLS connections with ALPN disabled will be disallowed in future grpc-go releases")
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,252 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
|
"google.golang.org/grpc/testdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTestTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
type s struct {
|
||||||
|
grpctest.Tester
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
grpctest.RunSubTests(t, s{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestTLSOverrideServerName(t *testing.T) {
|
||||||
|
expectedServerName := "server.name"
|
||||||
|
c := NewTLSWithALPNDisabled(nil)
|
||||||
|
c.OverrideServerName(expectedServerName)
|
||||||
|
if c.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestTLSClone(t *testing.T) {
|
||||||
|
expectedServerName := "server.name"
|
||||||
|
c := NewTLSWithALPNDisabled(nil)
|
||||||
|
c.OverrideServerName(expectedServerName)
|
||||||
|
cc := c.Clone()
|
||||||
|
if cc.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
cc.OverrideServerName("")
|
||||||
|
if c.Info().ServerName != expectedServerName {
|
||||||
|
t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverHandshake func(net.Conn) (credentials.AuthInfo, error)
|
||||||
|
|
||||||
|
func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) {
|
||||||
|
tcs := []struct {
|
||||||
|
name string
|
||||||
|
address string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "localhost",
|
||||||
|
address: "localhost:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4",
|
||||||
|
address: "127.0.0.1:0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
address: "[::1]:0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
done := make(chan credentials.AuthInfo, 1)
|
||||||
|
lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address)
|
||||||
|
defer lis.Close()
|
||||||
|
lisAddr := lis.Addr().String()
|
||||||
|
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr)
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) {
|
||||||
|
done := make(chan credentials.AuthInfo, 1)
|
||||||
|
lis := launchServer(t, gRPCServerHandshake, done)
|
||||||
|
defer lis.Close()
|
||||||
|
clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String())
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestServerAndClientHandshake(t *testing.T) {
|
||||||
|
done := make(chan credentials.AuthInfo, 1)
|
||||||
|
lis := launchServer(t, gRPCServerHandshake, done)
|
||||||
|
defer lis.Close()
|
||||||
|
clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String())
|
||||||
|
// wait until server sends serverAuthInfo or fails.
|
||||||
|
serverAuthInfo, ok := <-done
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Error at server-side")
|
||||||
|
}
|
||||||
|
if !compare(clientAuthInfo, serverAuthInfo) {
|
||||||
|
t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compare(a1, a2 credentials.AuthInfo) bool {
|
||||||
|
if a1.AuthType() != a2.AuthType() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch a1.AuthType() {
|
||||||
|
case "tls":
|
||||||
|
state1 := a1.(credentials.TLSInfo).State
|
||||||
|
state2 := a2.(credentials.TLSInfo).State
|
||||||
|
if state1.Version == state2.Version &&
|
||||||
|
state1.HandshakeComplete == state2.HandshakeComplete &&
|
||||||
|
state1.CipherSuite == state2.CipherSuite &&
|
||||||
|
state1.NegotiatedProtocol == state2.NegotiatedProtocol {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchServer(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo) net.Listener {
|
||||||
|
return launchServerOnListenAddress(t, hs, done, "localhost:0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, address string) net.Listener {
|
||||||
|
lis, err := net.Listen("tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
|
||||||
|
strings.Contains(err.Error(), "socket: address family not supported by protocol") {
|
||||||
|
t.Skipf("no support for address %v", address)
|
||||||
|
}
|
||||||
|
t.Fatalf("Failed to listen: %v", err)
|
||||||
|
}
|
||||||
|
go serverHandle(t, hs, done, lis)
|
||||||
|
return lis
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is run in a separate goroutine.
|
||||||
|
func serverHandle(t *testing.T, hs serverHandshake, done chan credentials.AuthInfo, lis net.Listener) {
|
||||||
|
serverRawConn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Server failed to accept connection: %v", err)
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverAuthInfo, err := hs(serverRawConn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Server failed while handshake. Error: %v", err)
|
||||||
|
serverRawConn.Close()
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
done <- serverAuthInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), lisAddr string) credentials.AuthInfo {
|
||||||
|
conn, err := net.Dial("tcp", lisAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
clientAuthInfo, err := hs(conn, lisAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error on client while handshake. Error: %v", err)
|
||||||
|
}
|
||||||
|
return clientAuthInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server handshake implementation in gRPC.
|
||||||
|
func gRPCServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
|
||||||
|
serverTLS, err := NewServerTLSFromFileWithALPNDisabled(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, serverAuthInfo, err := serverTLS.ServerHandshake(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return serverAuthInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client handshake implementation in gRPC.
|
||||||
|
func gRPCClientHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
|
||||||
|
clientTLS := NewTLSWithALPNDisabled(&tls.Config{InsecureSkipVerify: true})
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
_, authInfo, err := clientTLS.ClientHandshake(ctx, lisAddr, conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return authInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tlsServerHandshake(conn net.Conn) (credentials.AuthInfo, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||||
|
serverConn := tls.Server(conn, serverTLSConfig)
|
||||||
|
err = serverConn.Handshake()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return credentials.TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func tlsClientHandshake(conn net.Conn, _ string) (credentials.AuthInfo, error) {
|
||||||
|
clientTLSConfig := &tls.Config{InsecureSkipVerify: true}
|
||||||
|
clientConn := tls.Client(conn, clientTLSConfig)
|
||||||
|
if err := clientConn.Handshake(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return credentials.TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}, nil
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package internal defines APIs for parsing SPIFFE ID.
|
||||||
|
//
|
||||||
|
// All APIs in this package are experimental.
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/grpclog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = grpclog.Component("credentials")
|
||||||
|
|
||||||
|
// SPIFFEIDFromState parses the SPIFFE ID from State. If the SPIFFE ID format
|
||||||
|
// is invalid, return nil with warning.
|
||||||
|
func SPIFFEIDFromState(state tls.ConnectionState) *url.URL {
|
||||||
|
if len(state.PeerCertificates) == 0 || len(state.PeerCertificates[0].URIs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return SPIFFEIDFromCert(state.PeerCertificates[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// SPIFFEIDFromCert parses the SPIFFE ID from x509.Certificate. If the SPIFFE
|
||||||
|
// ID format is invalid, return nil with warning.
|
||||||
|
func SPIFFEIDFromCert(cert *x509.Certificate) *url.URL {
|
||||||
|
if cert == nil || cert.URIs == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var spiffeID *url.URL
|
||||||
|
for _, uri := range cert.URIs {
|
||||||
|
if uri == nil || uri.Scheme != "spiffe" || uri.Opaque != "" || (uri.User != nil && uri.User.Username() != "") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// From this point, we assume the uri is intended for a SPIFFE ID.
|
||||||
|
if len(uri.String()) > 2048 {
|
||||||
|
logger.Warning("invalid SPIFFE ID: total ID length larger than 2048 bytes")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(uri.Host) == 0 || len(uri.Path) == 0 {
|
||||||
|
logger.Warning("invalid SPIFFE ID: domain or workload ID is empty")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(uri.Host) > 255 {
|
||||||
|
logger.Warning("invalid SPIFFE ID: domain length larger than 255 characters")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// A valid SPIFFE certificate can only have exactly one URI SAN field.
|
||||||
|
if len(cert.URIs) > 1 {
|
||||||
|
logger.Warning("invalid SPIFFE ID: multiple URI SANs")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
spiffeID = uri
|
||||||
|
}
|
||||||
|
return spiffeID
|
||||||
|
}
|
|
@ -0,0 +1,233 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
|
"google.golang.org/grpc/testdata"
|
||||||
|
)
|
||||||
|
|
||||||
|
const wantURI = "spiffe://foo.bar.com/client/workload/1"
|
||||||
|
|
||||||
|
type s struct {
|
||||||
|
grpctest.Tester
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
grpctest.RunSubTests(t, s{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestSPIFFEIDFromState(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urls []*url.URL
|
||||||
|
// If we expect a SPIFFE ID to be returned.
|
||||||
|
wantID bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty URIs",
|
||||||
|
urls: []*url.URL{},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "good SPIFFE ID",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid host",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid path",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "",
|
||||||
|
RawPath: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large path",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: string(make([]byte, 2050)),
|
||||||
|
RawPath: string(make([]byte, 2050)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large host",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: string(make([]byte, 256)),
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple URI SANs",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "bar.baz.com",
|
||||||
|
Path: "workload/wl2",
|
||||||
|
RawPath: "workload/wl2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple URI SANs without SPIFFE ID",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Scheme: "ssh",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple URI SANs with one SPIFFE ID",
|
||||||
|
urls: []*url.URL{
|
||||||
|
{
|
||||||
|
Scheme: "spiffe",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "foo.bar.com",
|
||||||
|
Path: "workload/wl1",
|
||||||
|
RawPath: "workload/wl1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
state := tls.ConnectionState{PeerCertificates: []*x509.Certificate{{URIs: tt.urls}}}
|
||||||
|
id := SPIFFEIDFromState(state)
|
||||||
|
if got, want := id != nil, tt.wantID; got != want {
|
||||||
|
t.Errorf("want wantID = %v, but SPIFFE ID is %v", want, id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestSPIFFEIDFromCert(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataPath string
|
||||||
|
// If we expect a SPIFFE ID to be returned.
|
||||||
|
wantID bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "good certificate with SPIFFE ID",
|
||||||
|
dataPath: "x509/spiffe_cert.pem",
|
||||||
|
wantID: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad certificate with SPIFFE ID and another URI",
|
||||||
|
dataPath: "x509/multiple_uri_cert.pem",
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "certificate without SPIFFE ID",
|
||||||
|
dataPath: "x509/client1_cert.pem",
|
||||||
|
wantID: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data, err := os.ReadFile(testdata.Path(tt.dataPath))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("os.ReadFile(%s) failed: %v", testdata.Path(tt.dataPath), err)
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(data)
|
||||||
|
if block == nil {
|
||||||
|
t.Fatalf("Failed to parse the certificate: byte block is nil")
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.ParseCertificate(%b) failed: %v", block.Bytes, err)
|
||||||
|
}
|
||||||
|
uri := SPIFFEIDFromCert(cert)
|
||||||
|
if (uri != nil) != tt.wantID {
|
||||||
|
t.Fatalf("wantID got and want mismatch, got %t, want %t", uri != nil, tt.wantID)
|
||||||
|
}
|
||||||
|
if uri != nil && uri.String() != wantURI {
|
||||||
|
t.Fatalf("SPIFFE ID not expected, got %s, want %s", uri.String(), wantURI)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,58 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sysConn = syscall.Conn
|
||||||
|
|
||||||
|
// syscallConn keeps reference of rawConn to support syscall.Conn for channelz.
|
||||||
|
// SyscallConn() (the method in interface syscall.Conn) is explicitly
|
||||||
|
// implemented on this type,
|
||||||
|
//
|
||||||
|
// Interface syscall.Conn is implemented by most net.Conn implementations (e.g.
|
||||||
|
// TCPConn, UnixConn), but is not part of net.Conn interface. So wrapper conns
|
||||||
|
// that embed net.Conn don't implement syscall.Conn. (Side note: tls.Conn
|
||||||
|
// doesn't embed net.Conn, so even if syscall.Conn is part of net.Conn, it won't
|
||||||
|
// help here).
|
||||||
|
type syscallConn struct {
|
||||||
|
net.Conn
|
||||||
|
// sysConn is a type alias of syscall.Conn. It's necessary because the name
|
||||||
|
// `Conn` collides with `net.Conn`.
|
||||||
|
sysConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapSyscallConn tries to wrap rawConn and newConn into a net.Conn that
|
||||||
|
// implements syscall.Conn. rawConn will be used to support syscall, and newConn
|
||||||
|
// will be used for read/write.
|
||||||
|
//
|
||||||
|
// This function returns newConn if rawConn doesn't implement syscall.Conn.
|
||||||
|
func WrapSyscallConn(rawConn, newConn net.Conn) net.Conn {
|
||||||
|
sysConn, ok := rawConn.(syscall.Conn)
|
||||||
|
if !ok {
|
||||||
|
return newConn
|
||||||
|
}
|
||||||
|
return &syscallConn{
|
||||||
|
Conn: newConn,
|
||||||
|
sysConn: sysConn,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (*syscallConn) SyscallConn() (syscall.RawConn, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type nonSyscallConn struct {
|
||||||
|
net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestWrapSyscallConn(t *testing.T) {
|
||||||
|
sc := &syscallConn{}
|
||||||
|
nsc := &nonSyscallConn{}
|
||||||
|
|
||||||
|
wrapConn := WrapSyscallConn(sc, nsc)
|
||||||
|
if _, ok := wrapConn.(syscall.Conn); !ok {
|
||||||
|
t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement", wrapConn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s) TestWrapSyscallConnNoWrap(t *testing.T) {
|
||||||
|
nscRaw := &nonSyscallConn{}
|
||||||
|
nsc := &nonSyscallConn{}
|
||||||
|
|
||||||
|
wrapConn := WrapSyscallConn(nscRaw, nsc)
|
||||||
|
if _, ok := wrapConn.(syscall.Conn); ok {
|
||||||
|
t.Errorf("returned conn (type %T) implements syscall.Conn, want not implement", wrapConn)
|
||||||
|
}
|
||||||
|
if wrapConn != nsc {
|
||||||
|
t.Errorf("returned conn is %p, want %p (the passed-in newConn)", wrapConn, nsc)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,249 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package credentials provides experimental TLS credentials.
|
||||||
|
// The use of this package is strongly discouraged. These credentials exist
|
||||||
|
// solely to maintain compatibility for users interacting with clients that
|
||||||
|
// violate the HTTP/2 specification by not advertising support for "h2" in ALPN.
|
||||||
|
// This package is slated for removal in upcoming grpc-go releases. Users must
|
||||||
|
// not rely on this package directly. Instead, they should either vendor a
|
||||||
|
// specific version of gRPC or copy the relevant credentials into their own
|
||||||
|
// codebase if absolutely necessary.
|
||||||
|
package credentials
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/experimental/credentials/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// tlsCreds is the credentials required for authenticating a connection using TLS.
|
||||||
|
type tlsCreds struct {
|
||||||
|
// TLS configuration
|
||||||
|
config *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c tlsCreds) Info() credentials.ProtocolInfo {
|
||||||
|
return credentials.ProtocolInfo{
|
||||||
|
SecurityProtocol: "tls",
|
||||||
|
SecurityVersion: "1.2",
|
||||||
|
ServerName: c.config.ServerName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ credentials.AuthInfo, err error) {
|
||||||
|
// use local cfg to avoid clobbering ServerName if using multiple endpoints
|
||||||
|
cfg := cloneTLSConfig(c.config)
|
||||||
|
if cfg.ServerName == "" {
|
||||||
|
serverName, _, err := net.SplitHostPort(authority)
|
||||||
|
if err != nil {
|
||||||
|
// If the authority had no host port or if the authority cannot be parsed, use it as-is.
|
||||||
|
serverName = authority
|
||||||
|
}
|
||||||
|
cfg.ServerName = serverName
|
||||||
|
}
|
||||||
|
conn := tls.Client(rawConn, cfg)
|
||||||
|
errChannel := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errChannel <- conn.Handshake()
|
||||||
|
close(errChannel)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <-errChannel:
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
conn.Close()
|
||||||
|
return nil, nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsInfo := credentials.TLSInfo{
|
||||||
|
State: conn.ConnectionState(),
|
||||||
|
CommonAuthInfo: credentials.CommonAuthInfo{
|
||||||
|
SecurityLevel: credentials.PrivacyAndIntegrity,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
id := internal.SPIFFEIDFromState(conn.ConnectionState())
|
||||||
|
if id != nil {
|
||||||
|
tlsInfo.SPIFFEID = id
|
||||||
|
}
|
||||||
|
return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||||
|
conn := tls.Server(rawConn, c.config)
|
||||||
|
if err := conn.Handshake(); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
cs := conn.ConnectionState()
|
||||||
|
tlsInfo := credentials.TLSInfo{
|
||||||
|
State: cs,
|
||||||
|
CommonAuthInfo: credentials.CommonAuthInfo{
|
||||||
|
SecurityLevel: credentials.PrivacyAndIntegrity,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
id := internal.SPIFFEIDFromState(conn.ConnectionState())
|
||||||
|
if id != nil {
|
||||||
|
tlsInfo.SPIFFEID = id
|
||||||
|
}
|
||||||
|
return internal.WrapSyscallConn(rawConn, conn), tlsInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) Clone() credentials.TransportCredentials {
|
||||||
|
return NewTLSWithALPNDisabled(c.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tlsCreds) OverrideServerName(serverNameOverride string) error {
|
||||||
|
c.config.ServerName = serverNameOverride
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The following cipher suites are forbidden for use with HTTP/2 by
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
|
||||||
|
var tls12ForbiddenCipherSuites = map[uint16]struct{}{
|
||||||
|
tls.TLS_RSA_WITH_AES_128_CBC_SHA: {},
|
||||||
|
tls.TLS_RSA_WITH_AES_256_CBC_SHA: {},
|
||||||
|
tls.TLS_RSA_WITH_AES_128_GCM_SHA256: {},
|
||||||
|
tls.TLS_RSA_WITH_AES_256_GCM_SHA384: {},
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA: {},
|
||||||
|
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: {},
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: {},
|
||||||
|
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLSWithALPNDisabled uses c to construct a TransportCredentials based on
|
||||||
|
// TLS. ALPN verification is disabled.
|
||||||
|
func NewTLSWithALPNDisabled(c *tls.Config) credentials.TransportCredentials {
|
||||||
|
config := applyDefaults(c)
|
||||||
|
if config.GetConfigForClient != nil {
|
||||||
|
oldFn := config.GetConfigForClient
|
||||||
|
config.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
cfgForClient, err := oldFn(hello)
|
||||||
|
if err != nil || cfgForClient == nil {
|
||||||
|
return cfgForClient, err
|
||||||
|
}
|
||||||
|
return applyDefaults(cfgForClient), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &tlsCreds{config: config}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDefaults(c *tls.Config) *tls.Config {
|
||||||
|
config := cloneTLSConfig(c)
|
||||||
|
config.NextProtos = appendH2ToNextProtos(config.NextProtos)
|
||||||
|
// If the user did not configure a MinVersion and did not configure a
|
||||||
|
// MaxVersion < 1.2, use MinVersion=1.2, which is required by
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc7540#section-9.2
|
||||||
|
if config.MinVersion == 0 && (config.MaxVersion == 0 || config.MaxVersion >= tls.VersionTLS12) {
|
||||||
|
config.MinVersion = tls.VersionTLS12
|
||||||
|
}
|
||||||
|
// If the user did not configure CipherSuites, use all "secure" cipher
|
||||||
|
// suites reported by the TLS package, but remove some explicitly forbidden
|
||||||
|
// by https://datatracker.ietf.org/doc/html/rfc7540#appendix-A
|
||||||
|
if config.CipherSuites == nil {
|
||||||
|
for _, cs := range tls.CipherSuites() {
|
||||||
|
if _, ok := tls12ForbiddenCipherSuites[cs.ID]; !ok {
|
||||||
|
config.CipherSuites = append(config.CipherSuites, cs.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientTLSFromCertWithALPNDisabled constructs TLS credentials from the
|
||||||
|
// provided root certificate authority certificate(s) to validate server
|
||||||
|
// connections. If certificates to establish the identity of the client need to
|
||||||
|
// be included in the credentials (eg: for mTLS), use NewTLS instead, where a
|
||||||
|
// complete tls.Config can be specified.
|
||||||
|
// serverNameOverride is for testing only. If set to a non empty string,
|
||||||
|
// it will override the virtual host name of authority (e.g. :authority header
|
||||||
|
// field) in requests. ALPN verification is disabled.
|
||||||
|
func NewClientTLSFromCertWithALPNDisabled(cp *x509.CertPool, serverNameOverride string) credentials.TransportCredentials {
|
||||||
|
return NewTLSWithALPNDisabled(&tls.Config{ServerName: serverNameOverride, RootCAs: cp})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientTLSFromFileWithALPNDisabled constructs TLS credentials from the
|
||||||
|
// provided root certificate authority certificate file(s) to validate server
|
||||||
|
// connections. If certificates to establish the identity of the client need to
|
||||||
|
// be included in the credentials (eg: for mTLS), use NewTLS instead, where a
|
||||||
|
// complete tls.Config can be specified.
|
||||||
|
// serverNameOverride is for testing only. If set to a non empty string,
|
||||||
|
// it will override the virtual host name of authority (e.g. :authority header
|
||||||
|
// field) in requests. ALPN verification is disabled.
|
||||||
|
func NewClientTLSFromFileWithALPNDisabled(certFile, serverNameOverride string) (credentials.TransportCredentials, error) {
|
||||||
|
b, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cp := x509.NewCertPool()
|
||||||
|
if !cp.AppendCertsFromPEM(b) {
|
||||||
|
return nil, fmt.Errorf("credentials: failed to append certificates")
|
||||||
|
}
|
||||||
|
return NewTLSWithALPNDisabled(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerTLSFromCertWithALPNDisabled constructs TLS credentials from the
|
||||||
|
// input certificate for server. ALPN verification is disabled.
|
||||||
|
func NewServerTLSFromCertWithALPNDisabled(cert *tls.Certificate) credentials.TransportCredentials {
|
||||||
|
return NewTLSWithALPNDisabled(&tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerTLSFromFileWithALPNDisabled constructs TLS credentials from the
|
||||||
|
// input certificate file and key file for server. ALPN verification is disabled.
|
||||||
|
func NewServerTLSFromFileWithALPNDisabled(certFile, keyFile string) (credentials.TransportCredentials, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewTLSWithALPNDisabled(&tls.Config{Certificates: []tls.Certificate{cert}}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneTLSConfig returns a shallow clone of the exported
|
||||||
|
// fields of cfg, ignoring the unexported sync.Once, which
|
||||||
|
// contains a mutex and must not be copied.
|
||||||
|
//
|
||||||
|
// If cfg is nil, a new zero tls.Config is returned.
|
||||||
|
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendH2ToNextProtos appends h2 to next protos.
|
||||||
|
func appendH2ToNextProtos(ps []string) []string {
|
||||||
|
for _, p := range ps {
|
||||||
|
if p == http2.NextProtoTLS {
|
||||||
|
return ps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret := make([]string, 0, len(ps)+1)
|
||||||
|
ret = append(ret, ps...)
|
||||||
|
return append(ret, http2.NextProtoTLS)
|
||||||
|
}
|
|
@ -0,0 +1,604 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* Copyright 2025 gRPC authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package credentials_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
credsstable "google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/experimental/credentials"
|
||||||
|
"google.golang.org/grpc/internal/envconfig"
|
||||||
|
"google.golang.org/grpc/internal/grpctest"
|
||||||
|
"google.golang.org/grpc/internal/stubserver"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/grpc/testdata"
|
||||||
|
|
||||||
|
testgrpc "google.golang.org/grpc/interop/grpc_testing"
|
||||||
|
testpb "google.golang.org/grpc/interop/grpc_testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTestTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
type s struct {
|
||||||
|
grpctest.Tester
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
grpctest.RunSubTests(t, s{})
|
||||||
|
}
|
||||||
|
|
||||||
|
var serverCert tls.Certificate
|
||||||
|
var certPool *x509.CertPool
|
||||||
|
var serverName = "x.test.example.com"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
serverCert, err = tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("tls.LoadX509KeyPair(server1.pem, server1.key) failed: %v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := os.ReadFile(testdata.Path("x509/server_ca_cert.pem"))
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("Error reading CA cert file: %v", err))
|
||||||
|
}
|
||||||
|
certPool = x509.NewCertPool()
|
||||||
|
if !certPool.AppendCertsFromPEM(b) {
|
||||||
|
panic("Error appending cert from PEM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that the MinVersion of tls.Config is set to 1.2 if it is not already
|
||||||
|
// set by the user.
|
||||||
|
func (s) TestTLS_MinVersion12(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverTLS func() *tls.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base_case",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
// MinVersion should be set to 1.2 by gRPC by default.
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback_to_base",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
config := &tls.Config{
|
||||||
|
// MinVersion should be set to 1.2 by gRPC by default.
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dynamic_using_get_config_for_client",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{
|
||||||
|
// MinVersion should be set to 1.2 by gRPC by default.
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create server creds without a minimum version.
|
||||||
|
serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
|
||||||
|
ss := stubserver.StubServer{
|
||||||
|
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return &testpb.Empty{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create client creds that supports V1.0-V1.1.
|
||||||
|
clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
MaxVersion: tls.VersionTLS11,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server and client separately, because Start() blocks on a
|
||||||
|
// successful connection, which we will not get.
|
||||||
|
if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
|
||||||
|
t.Fatalf("Error starting server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
cc, err := grpc.NewClient(ss.Address, grpc.WithTransportCredentials(clientCreds))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("grpc.NewClient error: %v", err)
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
|
client := testgrpc.NewTestServiceClient(cc)
|
||||||
|
|
||||||
|
const wantStr = "authentication handshake failed"
|
||||||
|
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
|
||||||
|
t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that the MinVersion of tls.Config is not changed if it is set by the
|
||||||
|
// user.
|
||||||
|
func (s) TestTLS_MinVersionOverridable(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var allCipherSuites []uint16
|
||||||
|
for _, cs := range tls.CipherSuites() {
|
||||||
|
allCipherSuites = append(allCipherSuites, cs.ID)
|
||||||
|
}
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverTLS func() *tls.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base_case",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: allCipherSuites,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback_to_base",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
config := &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: allCipherSuites,
|
||||||
|
}
|
||||||
|
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dynamic_using_get_config_for_client",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: allCipherSuites,
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create server creds that allow v1.0.
|
||||||
|
serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
|
||||||
|
ss := stubserver.StubServer{
|
||||||
|
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return &testpb.Empty{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create client creds that supports V1.0-V1.1.
|
||||||
|
clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
CipherSuites: allCipherSuites,
|
||||||
|
MinVersion: tls.VersionTLS10,
|
||||||
|
MaxVersion: tls.VersionTLS11,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
|
||||||
|
t.Fatalf("Error starting stub server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||||
|
t.Fatalf("EmptyCall err = %v; want <nil>", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that CipherSuites is set to exclude HTTP/2 forbidden suites by default.
|
||||||
|
func (s) TestTLS_CipherSuites(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverTLS func() *tls.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base_case",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback_to_base",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
config := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}
|
||||||
|
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dynamic_using_get_config_for_client",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create server creds without cipher suites.
|
||||||
|
serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
|
||||||
|
ss := stubserver.StubServer{
|
||||||
|
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return &testpb.Empty{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create client creds that use a forbidden suite only.
|
||||||
|
clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
|
||||||
|
MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start server and client separately, because Start() blocks on a
|
||||||
|
// successful connection, which we will not get.
|
||||||
|
if err := ss.StartServer(grpc.Creds(serverCreds)); err != nil {
|
||||||
|
t.Fatalf("Error starting server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
cc, err := grpc.NewClient("dns:"+ss.Address, grpc.WithTransportCredentials(clientCreds))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("grpc.NewClient error: %v", err)
|
||||||
|
}
|
||||||
|
defer cc.Close()
|
||||||
|
|
||||||
|
client := testgrpc.NewTestServiceClient(cc)
|
||||||
|
|
||||||
|
const wantStr = "authentication handshake failed"
|
||||||
|
if _, err = client.EmptyCall(ctx, &testpb.Empty{}); status.Code(err) != codes.Unavailable || !strings.Contains(status.Convert(err).Message(), wantStr) {
|
||||||
|
t.Fatalf("EmptyCall err = %v; want code=%v, message contains %q", err, codes.Unavailable, wantStr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests that CipherSuites is not overridden when it is set.
|
||||||
|
func (s) TestTLS_CipherSuitesOverridable(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
serverTLS func() *tls.Config
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "base_case",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback_to_base",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
config := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
|
||||||
|
}
|
||||||
|
config.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "dynamic_using_get_config_for_client",
|
||||||
|
serverTLS: func() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create server that allows only a forbidden cipher suite.
|
||||||
|
serverCreds := credentials.NewTLSWithALPNDisabled(tc.serverTLS())
|
||||||
|
ss := stubserver.StubServer{
|
||||||
|
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return &testpb.Empty{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create server that allows only a forbidden cipher suite.
|
||||||
|
clientCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA},
|
||||||
|
MaxVersion: tls.VersionTLS12, // TLS1.3 cipher suites are not configurable, so limit to 1.2.
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
|
||||||
|
t.Fatalf("Error starting stub server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||||
|
t.Fatalf("EmptyCall err = %v; want <nil>", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTLS_ServerConfiguresALPNByDefault verifies that ALPN is configured
|
||||||
|
// correctly for a server that doesn't specify the NextProtos field and uses
|
||||||
|
// GetConfigForClient to provide the TLS config during the handshake.
|
||||||
|
func (s) TestTLS_ServerConfiguresALPNByDefault(t *testing.T) {
|
||||||
|
initialVal := envconfig.EnforceALPNEnabled
|
||||||
|
defer func() {
|
||||||
|
envconfig.EnforceALPNEnabled = initialVal
|
||||||
|
}()
|
||||||
|
envconfig.EnforceALPNEnabled = true
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create a server that doesn't set the NextProtos field.
|
||||||
|
serverCreds := credentials.NewTLSWithALPNDisabled(&tls.Config{
|
||||||
|
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ss := stubserver.StubServer{
|
||||||
|
EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
return &testpb.Empty{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCreds := credsstable.NewTLS(&tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := ss.Start([]grpc.ServerOption{grpc.Creds(serverCreds)}, grpc.WithTransportCredentials(clientCreds)); err != nil {
|
||||||
|
t.Fatalf("Error starting stub server: %v", err)
|
||||||
|
}
|
||||||
|
defer ss.Stop()
|
||||||
|
|
||||||
|
if _, err := ss.Client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
|
||||||
|
t.Fatalf("EmptyCall err = %v; want <nil>", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTLS_DisabledALPNClient tests the behaviour of TransportCredentials when
|
||||||
|
// connecting to a server that doesn't support ALPN.
|
||||||
|
func (s) TestTLS_DisabledALPNClient(t *testing.T) {
|
||||||
|
initialVal := envconfig.EnforceALPNEnabled
|
||||||
|
defer func() {
|
||||||
|
envconfig.EnforceALPNEnabled = initialVal
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
alpnEnforced bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforced",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_enforced",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
envconfig.EnforceALPNEnabled = tc.alpnEnforced
|
||||||
|
|
||||||
|
listener, err := tls.Listen("tcp", "localhost:0", &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error starting TLS server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
|
||||||
|
} else {
|
||||||
|
// The first write to the TLS listener initiates the TLS handshake.
|
||||||
|
conn.Write([]byte("Hello, World!"))
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
close(errCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverAddr := listener.Addr().String()
|
||||||
|
conn, err := net.Dial("tcp", serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("net.Dial(%s) failed: %v", serverAddr, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
clientCfg := tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
RootCAs: certPool,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
}
|
||||||
|
_, _, err = credentials.NewTLSWithALPNDisabled(&clientCfg).ClientHandshake(ctx, serverName, conn)
|
||||||
|
|
||||||
|
if gotErr := (err != nil); gotErr != tc.wantErr {
|
||||||
|
t.Errorf("ClientHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error received from server: %v", err)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatalf("Timeout waiting for error from server")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTLS_DisabledALPNServer tests the behaviour of TransportCredentials when
|
||||||
|
// accepting a request from a client that doesn't support ALPN.
|
||||||
|
func (s) TestTLS_DisabledALPNServer(t *testing.T) {
|
||||||
|
initialVal := envconfig.EnforceALPNEnabled
|
||||||
|
defer func() {
|
||||||
|
envconfig.EnforceALPNEnabled = initialVal
|
||||||
|
}()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
alpnEnforced bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "enforced",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not_enforced",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
envconfig.EnforceALPNEnabled = tc.alpnEnforced
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error starting server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("listener.Accept returned error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
serverCfg := tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
}
|
||||||
|
_, _, err = credentials.NewTLSWithALPNDisabled(&serverCfg).ServerHandshake(conn)
|
||||||
|
if gotErr := (err != nil); gotErr != tc.wantErr {
|
||||||
|
t.Errorf("ServerHandshake returned unexpected error: got=%v, want=%t", err, tc.wantErr)
|
||||||
|
}
|
||||||
|
close(errCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverAddr := listener.Addr().String()
|
||||||
|
clientCfg := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{serverCert},
|
||||||
|
NextProtos: []string{}, // Empty list indicates ALPN is disabled.
|
||||||
|
RootCAs: certPool,
|
||||||
|
ServerName: serverName,
|
||||||
|
}
|
||||||
|
conn, err := tls.Dial("tcp", serverAddr, clientCfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls.Dial(%s) failed: %v", serverAddr, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(defaultTestTimeout):
|
||||||
|
t.Fatal("Timed out waiting for completion")
|
||||||
|
case err := <-errCh:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected server error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue