mirror of https://github.com/grpc/grpc-go.git
xDS: Atomically read and write xDS security configuration client side (#6796)
This commit is contained in:
parent
ce3b538586
commit
59c0aec9dc
|
@ -27,6 +27,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
@ -114,7 +115,9 @@ func (c *credsImpl) ClientHandshake(ctx context.Context, authority string, rawCo
|
|||
if chi.Attributes == nil {
|
||||
return c.fallback.ClientHandshake(ctx, authority, rawConn)
|
||||
}
|
||||
hi := xdsinternal.GetHandshakeInfo(chi.Attributes)
|
||||
|
||||
uPtr := xdsinternal.GetHandshakeInfo(chi.Attributes)
|
||||
hi := (*xdsinternal.HandshakeInfo)(atomic.LoadPointer(uPtr))
|
||||
if hi.UseFallbackCreds() {
|
||||
return c.fallback.ClientHandshake(ctx, authority, rawConn)
|
||||
}
|
||||
|
|
|
@ -27,8 +27,10 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
|
@ -219,11 +221,13 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
|
|||
// Creating the HandshakeInfo and adding it to the attributes is very
|
||||
// similar to what the CDS balancer would do when it intercepts calls to
|
||||
// NewSubConn().
|
||||
info := xdsinternal.NewHandshakeInfo(root, identity)
|
||||
var sms []matcher.StringMatcher
|
||||
if sanExactMatch != "" {
|
||||
info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
|
||||
sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
|
||||
}
|
||||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
|
||||
info := xdsinternal.NewHandshakeInfo(root, identity, sms, false)
|
||||
uPtr := unsafe.Pointer(info)
|
||||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
|
||||
|
||||
// Moving the attributes from the resolver.Address to the context passed to
|
||||
// the handshaker is done in the transport layer. Since we directly call the
|
||||
|
@ -533,13 +537,12 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
|
|||
// Create a root provider which will fail the handshake because it does not
|
||||
// use the correct trust roots.
|
||||
root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
|
||||
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
|
||||
handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})
|
||||
|
||||
handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
|
||||
// We need to repeat most of what newTestContextWithHandshakeInfo() does
|
||||
// here because we need access to the underlying HandshakeInfo so that we
|
||||
// can update it before the next call to ClientHandshake().
|
||||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
|
||||
uPtr := unsafe.Pointer(handshakeInfo)
|
||||
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
|
||||
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
|
||||
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
|
||||
t.Fatal("ClientHandshake() succeeded when expected to fail")
|
||||
|
@ -560,7 +563,10 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
|
|||
// Create a new root provider which uses the correct trust roots. And update
|
||||
// the HandshakeInfo with the new provider.
|
||||
root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
|
||||
handshakeInfo.SetRootCertProvider(root2)
|
||||
handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
|
||||
// Update the existing pointer, which address attribute will continue to
|
||||
// point to.
|
||||
atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo))
|
||||
_, ai, err := creds.ClientHandshake(ctx, authority, conn)
|
||||
if err != nil {
|
||||
t.Fatalf("ClientHandshake() returned failed: %q", err)
|
||||
|
|
|
@ -122,7 +122,7 @@ func (s) TestServerCredsInvalidHandshakeInfo(t *testing.T) {
|
|||
t.Fatalf("NewServerCredentials(%v) failed: %v", opts, err)
|
||||
}
|
||||
|
||||
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil)
|
||||
info := xdsinternal.NewHandshakeInfo(&fakeProvider{}, nil, nil, false)
|
||||
conn := newWrappedConn(nil, info, time.Time{})
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil {
|
||||
t.Fatal("ServerHandshake succeeded without identity certificate provider in HandshakeInfo")
|
||||
|
@ -158,7 +158,7 @@ func (s) TestServerCredsProviderFailure(t *testing.T) {
|
|||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
|
||||
info := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, false)
|
||||
conn := newWrappedConn(nil, info, time.Time{})
|
||||
if _, _, err := creds.ServerHandshake(conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
|
||||
t.Fatalf("ServerHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
|
||||
|
@ -232,8 +232,7 @@ func (s) TestServerCredsHandshakeTimeout(t *testing.T) {
|
|||
// Create a test server which uses the xDS server credentials created above
|
||||
// to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server2_cert.pem", "x509/server2_key.pem"), nil, true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo created
|
||||
// above with a very small deadline.
|
||||
|
@ -285,8 +284,7 @@ func (s) TestServerCredsHandshakeFailure(t *testing.T) {
|
|||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
// Create a HandshakeInfo which has a root provider which does not match
|
||||
// the certificate sent by the client.
|
||||
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
hi := xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
|
@ -367,8 +365,7 @@ func (s) TestServerCredsHandshakeSuccess(t *testing.T) {
|
|||
// created above to perform TLS handshake on incoming connections.
|
||||
ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
|
||||
// Create a HandshakeInfo with information from the test table.
|
||||
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider)
|
||||
hi.SetRequireClientCert(test.requireClientCert)
|
||||
hi := xdsinternal.NewHandshakeInfo(test.rootProvider, test.identityProvider, nil, test.requireClientCert)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
|
@ -448,8 +445,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
|
|||
if cnt == 1 {
|
||||
// Create a HandshakeInfo which has a root provider which does not match
|
||||
// the certificate sent by the client.
|
||||
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/client2_cert.pem", "x509/client2_key.pem"), nil, true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
|
@ -463,8 +459,7 @@ func (s) TestServerCredsProviderSwitch(t *testing.T) {
|
|||
return handshakeResult{}
|
||||
}
|
||||
|
||||
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"))
|
||||
hi.SetRequireClientCert(true)
|
||||
hi = xdsinternal.NewHandshakeInfo(makeRootProvider(t, "x509/client_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), nil, true)
|
||||
|
||||
// Create a wrapped conn which can return the HandshakeInfo and
|
||||
// configured deadline to the xDS credentials' ServerHandshake()
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"google.golang.org/grpc/attributes"
|
||||
"google.golang.org/grpc/credentials/tls/certprovider"
|
||||
|
@ -66,59 +66,38 @@ func (hi *HandshakeInfo) Equal(other *HandshakeInfo) bool {
|
|||
}
|
||||
|
||||
// SetHandshakeInfo returns a copy of addr in which the Attributes field is
|
||||
// updated with hInfo.
|
||||
func SetHandshakeInfo(addr resolver.Address, hInfo *HandshakeInfo) resolver.Address {
|
||||
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hInfo)
|
||||
// updated with hiPtr.
|
||||
func SetHandshakeInfo(addr resolver.Address, hiPtr *unsafe.Pointer) resolver.Address {
|
||||
addr.Attributes = addr.Attributes.WithValue(handshakeAttrKey{}, hiPtr)
|
||||
return addr
|
||||
}
|
||||
|
||||
// GetHandshakeInfo returns a pointer to the HandshakeInfo stored in attr.
|
||||
func GetHandshakeInfo(attr *attributes.Attributes) *HandshakeInfo {
|
||||
// GetHandshakeInfo returns a pointer to the *HandshakeInfo stored in attr.
|
||||
func GetHandshakeInfo(attr *attributes.Attributes) *unsafe.Pointer {
|
||||
v := attr.Value(handshakeAttrKey{})
|
||||
hi, _ := v.(*HandshakeInfo)
|
||||
hi, _ := v.(*unsafe.Pointer)
|
||||
return hi
|
||||
}
|
||||
|
||||
// HandshakeInfo wraps all the security configuration required by client and
|
||||
// server handshake methods in xds credentials. The xDS implementation will be
|
||||
// responsible for populating these fields.
|
||||
//
|
||||
// Safe for concurrent access.
|
||||
type HandshakeInfo struct {
|
||||
mu sync.Mutex
|
||||
// All fields written at init time and read only after that, so no
|
||||
// synchronization needed.
|
||||
rootProvider certprovider.Provider
|
||||
identityProvider certprovider.Provider
|
||||
sanMatchers []matcher.StringMatcher // Only on the client side.
|
||||
requireClientCert bool // Only on server side.
|
||||
}
|
||||
|
||||
// SetRootCertProvider updates the root certificate provider.
|
||||
func (hi *HandshakeInfo) SetRootCertProvider(root certprovider.Provider) {
|
||||
hi.mu.Lock()
|
||||
hi.rootProvider = root
|
||||
hi.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetIdentityCertProvider updates the identity certificate provider.
|
||||
func (hi *HandshakeInfo) SetIdentityCertProvider(identity certprovider.Provider) {
|
||||
hi.mu.Lock()
|
||||
hi.identityProvider = identity
|
||||
hi.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetSANMatchers updates the list of SAN matchers.
|
||||
func (hi *HandshakeInfo) SetSANMatchers(sanMatchers []matcher.StringMatcher) {
|
||||
hi.mu.Lock()
|
||||
hi.sanMatchers = sanMatchers
|
||||
hi.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetRequireClientCert updates whether a client cert is required during the
|
||||
// ServerHandshake(). A value of true indicates that we are performing mTLS.
|
||||
func (hi *HandshakeInfo) SetRequireClientCert(require bool) {
|
||||
hi.mu.Lock()
|
||||
hi.requireClientCert = require
|
||||
hi.mu.Unlock()
|
||||
func NewHandshakeInfo(rootProvider certprovider.Provider, identityProvider certprovider.Provider, sanMatchers []matcher.StringMatcher, requireClientCert bool) *HandshakeInfo {
|
||||
return &HandshakeInfo{
|
||||
rootProvider: rootProvider,
|
||||
identityProvider: identityProvider,
|
||||
sanMatchers: sanMatchers,
|
||||
requireClientCert: requireClientCert,
|
||||
}
|
||||
}
|
||||
|
||||
// UseFallbackCreds returns true when fallback credentials are to be used based
|
||||
|
@ -127,24 +106,18 @@ func (hi *HandshakeInfo) UseFallbackCreds() bool {
|
|||
if hi == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
hi.mu.Lock()
|
||||
defer hi.mu.Unlock()
|
||||
return hi.identityProvider == nil && hi.rootProvider == nil
|
||||
}
|
||||
|
||||
// GetSANMatchersForTesting returns the SAN matchers stored in HandshakeInfo.
|
||||
// To be used only for testing purposes.
|
||||
func (hi *HandshakeInfo) GetSANMatchersForTesting() []matcher.StringMatcher {
|
||||
hi.mu.Lock()
|
||||
defer hi.mu.Unlock()
|
||||
return append([]matcher.StringMatcher{}, hi.sanMatchers...)
|
||||
}
|
||||
|
||||
// ClientSideTLSConfig constructs a tls.Config to be used in a client-side
|
||||
// handshake based on the contents of the HandshakeInfo.
|
||||
func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config, error) {
|
||||
hi.mu.Lock()
|
||||
// On the client side, rootProvider is mandatory. IdentityProvider is
|
||||
// optional based on whether the client is doing TLS or mTLS.
|
||||
if hi.rootProvider == nil {
|
||||
|
@ -153,7 +126,6 @@ func (hi *HandshakeInfo) ClientSideTLSConfig(ctx context.Context) (*tls.Config,
|
|||
// Since the call to KeyMaterial() can block, we read the providers under
|
||||
// the lock but call the actual function after releasing the lock.
|
||||
rootProv, idProv := hi.rootProvider, hi.identityProvider
|
||||
hi.mu.Unlock()
|
||||
|
||||
// InsecureSkipVerify needs to be set to true because we need to perform
|
||||
// custom verification to check the SAN on the received certificate.
|
||||
|
@ -188,7 +160,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
|
|||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: []string{"h2"},
|
||||
}
|
||||
hi.mu.Lock()
|
||||
// On the server side, identityProvider is mandatory. RootProvider is
|
||||
// optional based on whether the server is doing TLS or mTLS.
|
||||
if hi.identityProvider == nil {
|
||||
|
@ -200,7 +171,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
|
|||
if hi.requireClientCert {
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
hi.mu.Unlock()
|
||||
|
||||
// identityProvider is mandatory on the server side.
|
||||
km, err := idProv.KeyMaterial(ctx)
|
||||
|
@ -225,8 +195,6 @@ func (hi *HandshakeInfo) ServerSideTLSConfig(ctx context.Context) (*tls.Config,
|
|||
// If the list of SAN matchers in the HandshakeInfo is empty, this function
|
||||
// returns true for all input certificates.
|
||||
func (hi *HandshakeInfo) MatchingSANExists(cert *x509.Certificate) bool {
|
||||
hi.mu.Lock()
|
||||
defer hi.mu.Unlock()
|
||||
if len(hi.sanMatchers) == 0 {
|
||||
return true
|
||||
}
|
||||
|
@ -325,9 +293,3 @@ func dnsMatch(host, san string) bool {
|
|||
hostPrefix := strings.TrimSuffix(host, san[1:])
|
||||
return !strings.Contains(hostPrefix, ".")
|
||||
}
|
||||
|
||||
// NewHandshakeInfo returns a new instance of HandshakeInfo with the given root
|
||||
// and identity certificate providers.
|
||||
func NewHandshakeInfo(root, identity certprovider.Provider) *HandshakeInfo {
|
||||
return &HandshakeInfo{rootProvider: root, identityProvider: identity}
|
||||
}
|
||||
|
|
|
@ -188,8 +188,7 @@ func TestMatchingSANExists_FailureCases(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
hi := NewHandshakeInfo(nil, nil)
|
||||
hi.SetSANMatchers(test.sanMatchers)
|
||||
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)
|
||||
|
||||
if hi.MatchingSANExists(inputCert) {
|
||||
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v succeeded when expected to fail", inputCert, test.sanMatchers)
|
||||
|
@ -289,8 +288,7 @@ func TestMatchingSANExists_Success(t *testing.T) {
|
|||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
hi := NewHandshakeInfo(nil, nil)
|
||||
hi.SetSANMatchers(test.sanMatchers)
|
||||
hi := NewHandshakeInfo(nil, nil, test.sanMatchers, false)
|
||||
|
||||
if !hi.MatchingSANExists(inputCert) {
|
||||
t.Fatalf("hi.MatchingSANExists(%+v) with SAN matchers +%v failed when expected to succeed", inputCert, test.sanMatchers)
|
||||
|
|
|
@ -57,7 +57,7 @@ var (
|
|||
// GetXDSHandshakeInfoForTesting returns a pointer to the xds.HandshakeInfo
|
||||
// stored in the passed in attributes. This is set by
|
||||
// credentials/xds/xds.go.
|
||||
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *xds.HandshakeInfo
|
||||
GetXDSHandshakeInfoForTesting any // func (*attributes.Attributes) *unsafe.Pointer
|
||||
// GetServerCredentials returns the transport credentials configured on a
|
||||
// gRPC server. An xDS-enabled server needs to know what type of credentials
|
||||
// is configured on the underlying gRPC server. This is set by server.go.
|
||||
|
|
|
@ -21,6 +21,8 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"google.golang.org/grpc/balancer"
|
||||
"google.golang.org/grpc/balancer/base"
|
||||
|
@ -89,19 +91,21 @@ func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Bal
|
|||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
hi := xdsinternal.NewHandshakeInfo(nil, nil, nil, false)
|
||||
xdsHIPtr := unsafe.Pointer(hi)
|
||||
b := &cdsBalancer{
|
||||
bOpts: opts,
|
||||
childConfigParser: parser,
|
||||
serializer: grpcsync.NewCallbackSerializer(ctx),
|
||||
serializerCancel: cancel,
|
||||
xdsHI: xdsinternal.NewHandshakeInfo(nil, nil),
|
||||
xdsHIPtr: &xdsHIPtr,
|
||||
watchers: make(map[string]*watcherState),
|
||||
}
|
||||
b.ccw = &ccWrapper{
|
||||
ClientConn: cc,
|
||||
xdsHI: b.xdsHI,
|
||||
xdsHIPtr: b.xdsHIPtr,
|
||||
}
|
||||
b.logger = prefixLogger((b))
|
||||
b.logger = prefixLogger(b)
|
||||
b.logger.Infof("Created")
|
||||
|
||||
var creds credentials.TransportCredentials
|
||||
|
@ -149,11 +153,13 @@ type cdsBalancer struct {
|
|||
// The following fields are initialized at build time and are either
|
||||
// read-only after that or provide their own synchronization, and therefore
|
||||
// do not need to be guarded by a mutex.
|
||||
ccw *ccWrapper // ClientConn interface passed to child LB.
|
||||
bOpts balancer.BuildOptions // BuildOptions passed to child LB.
|
||||
childConfigParser balancer.ConfigParser // Config parser for cluster_resolver LB policy.
|
||||
xdsHI *xdsinternal.HandshakeInfo // Handshake info from security configuration.
|
||||
logger *grpclog.PrefixLogger // Prefix logger for all logging.
|
||||
ccw *ccWrapper // ClientConn interface passed to child LB.
|
||||
bOpts balancer.BuildOptions // BuildOptions passed to child LB.
|
||||
childConfigParser balancer.ConfigParser // Config parser for cluster_resolver LB policy.
|
||||
logger *grpclog.PrefixLogger // Prefix logger for all logging.
|
||||
xdsCredsInUse bool
|
||||
|
||||
xdsHIPtr *unsafe.Pointer // Accessed atomically.
|
||||
|
||||
// The serializer and its cancel func are initialized at build time, and the
|
||||
// rest of the fields here are only accessed from serializer callbacks (or
|
||||
|
@ -170,7 +176,6 @@ type cdsBalancer struct {
|
|||
// a new provider is to be created.
|
||||
cachedRoot certprovider.Provider
|
||||
cachedIdentity certprovider.Provider
|
||||
xdsCredsInUse bool
|
||||
}
|
||||
|
||||
// handleSecurityConfig processes the security configuration received from the
|
||||
|
@ -186,6 +191,7 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e
|
|||
if !b.xdsCredsInUse {
|
||||
return nil
|
||||
}
|
||||
var xdsHI *xdsinternal.HandshakeInfo
|
||||
|
||||
// Security config being nil is a valid case where the management server has
|
||||
// not sent any security configuration. The xdsCredentials implementation
|
||||
|
@ -194,10 +200,10 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e
|
|||
// We need to explicitly set the fields to nil here since this might be
|
||||
// a case of switching from a good security configuration to an empty
|
||||
// one where fallback credentials are to be used.
|
||||
b.xdsHI.SetRootCertProvider(nil)
|
||||
b.xdsHI.SetIdentityCertProvider(nil)
|
||||
b.xdsHI.SetSANMatchers(nil)
|
||||
xdsHI = xdsinternal.NewHandshakeInfo(nil, nil, nil, false)
|
||||
atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI))
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
bc := b.xdsClient.BootstrapConfig()
|
||||
|
@ -234,12 +240,8 @@ func (b *cdsBalancer) handleSecurityConfig(config *xdsresource.SecurityConfig) e
|
|||
}
|
||||
b.cachedRoot = rootProvider
|
||||
b.cachedIdentity = identityProvider
|
||||
|
||||
// We set all fields here, even if some of them are nil, since they
|
||||
// could have been non-nil earlier.
|
||||
b.xdsHI.SetRootCertProvider(rootProvider)
|
||||
b.xdsHI.SetIdentityCertProvider(identityProvider)
|
||||
b.xdsHI.SetSANMatchers(config.SubjectAltNameMatchers)
|
||||
xdsHI = xdsinternal.NewHandshakeInfo(rootProvider, identityProvider, config.SubjectAltNameMatchers, false)
|
||||
atomic.StorePointer(b.xdsHIPtr, unsafe.Pointer(xdsHI))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -660,9 +662,7 @@ func (b *cdsBalancer) generateDMsForCluster(name string, depth int, dms []cluste
|
|||
type ccWrapper struct {
|
||||
balancer.ClientConn
|
||||
|
||||
// The certificate providers in this HandshakeInfo are updated based on the
|
||||
// received security configuration in the Cluster resource.
|
||||
xdsHI *xdsinternal.HandshakeInfo
|
||||
xdsHIPtr *unsafe.Pointer
|
||||
}
|
||||
|
||||
// NewSubConn intercepts NewSubConn() calls from the child policy and adds an
|
||||
|
@ -671,8 +671,9 @@ type ccWrapper struct {
|
|||
func (ccw *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) {
|
||||
newAddrs := make([]resolver.Address, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHI)
|
||||
newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr)
|
||||
}
|
||||
|
||||
// No need to override opts.StateListener; just forward all calls to the
|
||||
// child that created the SubConn.
|
||||
return ccw.ClientConn.NewSubConn(newAddrs, opts)
|
||||
|
@ -681,7 +682,7 @@ func (ccw *ccWrapper) NewSubConn(addrs []resolver.Address, opts balancer.NewSubC
|
|||
func (ccw *ccWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
|
||||
newAddrs := make([]resolver.Address, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHI)
|
||||
newAddrs[i] = xdsinternal.SetHandshakeInfo(addr, ccw.xdsHIPtr)
|
||||
}
|
||||
ccw.ClientConn.UpdateAddresses(sc, newAddrs)
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"unsafe"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -75,14 +76,15 @@ func (tcc *testCCWrapper) NewSubConn(addrs []resolver.Address, opts balancer.New
|
|||
if len(addrs) != 1 {
|
||||
return nil, fmt.Errorf("NewSubConn got %d addresses, want 1", len(addrs))
|
||||
}
|
||||
getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *xdscredsinternal.HandshakeInfo)
|
||||
getHI := internal.GetXDSHandshakeInfoForTesting.(func(attr *attributes.Attributes) *unsafe.Pointer)
|
||||
hi := getHI(addrs[0].Attributes)
|
||||
if hi == nil {
|
||||
return nil, fmt.Errorf("NewSubConn got address without xDS handshake info")
|
||||
}
|
||||
|
||||
sc, err := tcc.ClientConn.NewSubConn(addrs, opts)
|
||||
select {
|
||||
case tcc.handshakeInfoCh <- hi:
|
||||
case tcc.handshakeInfoCh <- (*xdscredsinternal.HandshakeInfo)(*hi):
|
||||
default:
|
||||
}
|
||||
return sc, err
|
||||
|
@ -292,7 +294,7 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) {
|
|||
case <-ctx.Done():
|
||||
t.Fatal("Timeout when waiting to read handshake info passed to NewSubConn")
|
||||
}
|
||||
wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil)
|
||||
wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil, nil, false)
|
||||
if !cmp.Equal(gotHI, wantHI) {
|
||||
t.Fatalf("NewSubConn got handshake info %+v, want %+v", gotHI, wantHI)
|
||||
}
|
||||
|
@ -343,7 +345,7 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) {
|
|||
case <-ctx.Done():
|
||||
t.Fatal("Timeout when waiting to read handshake info passed to NewSubConn")
|
||||
}
|
||||
wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil)
|
||||
wantHI := xdscredsinternal.NewHandshakeInfo(nil, nil, nil, false)
|
||||
if !cmp.Equal(gotHI, wantHI) {
|
||||
t.Fatalf("NewSubConn got handshake info %+v, want %+v", gotHI, wantHI)
|
||||
}
|
||||
|
|
|
@ -106,7 +106,7 @@ func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) {
|
|||
// did not provide any security configuration and therefore we should
|
||||
// return an empty HandshakeInfo here so that the xdsCreds can use the
|
||||
// configured fallback credentials.
|
||||
return xdsinternal.NewHandshakeInfo(nil, nil), nil
|
||||
return xdsinternal.NewHandshakeInfo(nil, nil, nil, false), nil
|
||||
}
|
||||
|
||||
cpc := c.parent.xdsC.BootstrapConfig().CertProviderConfigs
|
||||
|
@ -128,9 +128,7 @@ func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) {
|
|||
c.identityProvider = ip
|
||||
c.rootProvider = rp
|
||||
|
||||
xdsHI := xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider)
|
||||
xdsHI.SetRequireClientCert(secCfg.RequireClientCert)
|
||||
return xdsHI, nil
|
||||
return xdsinternal.NewHandshakeInfo(c.rootProvider, c.identityProvider, nil, secCfg.RequireClientCert), nil
|
||||
}
|
||||
|
||||
// Close closes the providers and the underlying connection.
|
||||
|
|
Loading…
Reference in New Issue