credentials/google: remove unnecessary dependency on xds protos (#4339)

This commit is contained in:
Doug Fawley 2021-04-13 16:19:17 -07:00 committed by GitHub
parent 6fafb9193b
commit 87eb5b7502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 96 additions and 78 deletions

View File

@ -30,7 +30,7 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/attributes"
"google.golang.org/grpc/internal"
icredentials "google.golang.org/grpc/internal/credentials"
)
// PerRPCCredentials defines the common interface for the credentials which need to
@ -188,15 +188,12 @@ type RequestInfo struct {
AuthInfo AuthInfo
}
// requestInfoKey is a struct to be used as the key when attaching a RequestInfo to a context object.
type requestInfoKey struct{}
// RequestInfoFromContext extracts the RequestInfo from the context if it exists.
//
// This API is experimental.
func RequestInfoFromContext(ctx context.Context) (ri RequestInfo, ok bool) {
ri, ok = ctx.Value(requestInfoKey{}).(RequestInfo)
return
ri, ok = icredentials.RequestInfoFromContext(ctx).(RequestInfo)
return ri, ok
}
// ClientHandshakeInfo holds data to be passed to ClientHandshake. This makes
@ -211,16 +208,12 @@ type ClientHandshakeInfo struct {
Attributes *attributes.Attributes
}
// clientHandshakeInfoKey is a struct used as the key to store
// ClientHandshakeInfo in a context.
type clientHandshakeInfoKey struct{}
// ClientHandshakeInfoFromContext returns the ClientHandshakeInfo struct stored
// in ctx.
//
// This API is experimental.
func ClientHandshakeInfoFromContext(ctx context.Context) ClientHandshakeInfo {
chi, _ := ctx.Value(clientHandshakeInfoKey{}).(ClientHandshakeInfo)
chi, _ := icredentials.ClientHandshakeInfoFromContext(ctx).(ClientHandshakeInfo)
return chi
}
@ -249,15 +242,6 @@ func CheckSecurityLevel(ai AuthInfo, level SecurityLevel) error {
return nil
}
func init() {
internal.NewRequestInfoContext = func(ctx context.Context, ri RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
}
internal.NewClientHandshakeInfoContext = func(ctx context.Context, chi ClientHandshakeInfo) context.Context {
return context.WithValue(ctx, clientHandshakeInfoKey{}, chi)
}
}
// ChannelzSecurityInfo defines the interface that security protocols should implement
// in order to provide security info to channelz.
//

View File

@ -25,7 +25,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/resolver"
)
@ -53,8 +53,6 @@ func (t *testAuthInfo) AuthType() string {
var (
testTLS = &testCreds{typ: "tls"}
testALTS = &testCreds{typ: "alts"}
contextWithHandshakeInfo = internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
)
func overrideNewCredsFuncs() func() {
@ -93,16 +91,16 @@ func TestClientHandshakeBasedOnClusterName(t *testing.T) {
},
{
name: "with non-CFE cluster name",
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
}),
// non-CFE backends should use alts.
wantTyp: "alts",
},
{
name: "with CFE cluster name",
ctx: contextWithHandshakeInfo(context.Background(), credentials.ClientHandshakeInfo{
Attributes: xdsinternal.SetHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes,
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cfeClusterName).Attributes,
}),
// CFE should use tls.
wantTyp: "tls",

View File

@ -23,7 +23,7 @@ import (
"net"
"google.golang.org/grpc/credentials"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal"
)
const cfeClusterName = "google-cfe"
@ -54,7 +54,7 @@ func (c *clusterTransportCreds) ClientHandshake(ctx context.Context, authority s
if chi.Attributes == nil {
return c.tls.ClientHandshake(ctx, authority, rawConn)
}
cn, ok := xdsinternal.GetHandshakeClusterName(chi.Attributes)
cn, ok := internal.GetXDSHandshakeClusterName(chi.Attributes)
if !ok || cn == cfeClusterName {
return c.tls.ClientHandshake(ctx, authority, rawConn)
}

View File

@ -37,7 +37,7 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
)
@ -104,7 +104,7 @@ func createTestContext(ctx context.Context, s credentials.SecurityLevel) context
Method: "testInfo",
AuthInfo: auth,
}
return internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
return icredentials.NewRequestInfoContext(ctx, ri)
}
// errReader implements the io.Reader interface and returns an error from the

View File

@ -32,7 +32,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/tls/certprovider"
"google.golang.org/grpc/internal"
icredentials "google.golang.org/grpc/internal/credentials"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/testutils"
@ -228,8 +228,7 @@ func newTestContextWithHandshakeInfo(parent context.Context, root, identity cert
// Moving the attributes from the resolver.Address to the context passed to
// the handshaker is done in the transport layer. Since we directly call the
// handshaker in these tests, we need to do the same here.
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
return contextWithHandshakeInfo(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
}
// compareAuthInfo compares the AuthInfo received on the client side after a
@ -541,8 +540,7 @@ func (s) TestClientCredsProviderSwitch(t *testing.T) {
// here because we need access to the underlying HandshakeInfo so that we
// can update it before the next call to ClientHandshake().
addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
ctx = contextWithHandshakeInfo(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
t.Fatal("ClientHandshake() succeeded when expected to fail")
}

View File

@ -0,0 +1,49 @@
/*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package credentials
import (
"context"
)
// requestInfoKey is a struct to be used as the key to store RequestInfo in a
// context.
type requestInfoKey struct{}
// NewRequestInfoContext creates a context with ri.
func NewRequestInfoContext(ctx context.Context, ri interface{}) context.Context {
return context.WithValue(ctx, requestInfoKey{}, ri)
}
// RequestInfoFromContext extracts the RequestInfo from ctx.
func RequestInfoFromContext(ctx context.Context) interface{} {
return ctx.Value(requestInfoKey{})
}
// clientHandshakeInfoKey is a struct used as the key to store
// ClientHandshakeInfo in a context.
type clientHandshakeInfoKey struct{}
// ClientHandshakeInfoFromContext extracts the ClientHandshakeInfo from ctx.
func ClientHandshakeInfoFromContext(ctx context.Context) interface{} {
return ctx.Value(clientHandshakeInfoKey{})
}
// NewClientHandshakeInfoContext creates a context with chi.
func NewClientHandshakeInfoContext(ctx context.Context, chi interface{}) context.Context {
return context.WithValue(ctx, clientHandshakeInfoKey{}, chi)
}

View File

@ -38,12 +38,6 @@ var (
// KeepaliveMinPingTime is the minimum ping interval. This must be 10s by
// default, but tests may wish to set it lower for convenience.
KeepaliveMinPingTime = 10 * time.Second
// NewRequestInfoContext creates a new context based on the argument context attaching
// the passed in RequestInfo to the new context.
NewRequestInfoContext interface{} // func(context.Context, credentials.RequestInfo) context.Context
// NewClientHandshakeInfoContext returns a copy of the input context with
// the passed in ClientHandshakeInfo struct added to it.
NewClientHandshakeInfoContext interface{} // func(context.Context, credentials.ClientHandshakeInfo) context.Context
// ParseServiceConfigForTesting is for creating a fake
// ClientConn for resolver testing only
ParseServiceConfigForTesting interface{} // func(string) *serviceconfig.ParseResult

View File

@ -32,15 +32,14 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpcutil"
imetadata "google.golang.org/grpc/internal/metadata"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/transport/networktype"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
@ -238,8 +237,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
// Attributes field of resolver.Address, which is shoved into connectCtx
// and passed to the credential handshaker. This makes it possible for
// address specific arbitrary data to reach the credential handshaker.
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
@ -458,7 +456,7 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr)
Method: callHdr.Method,
AuthInfo: t.authInfo,
}
ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, credentials.RequestInfo) context.Context)(ctx, ri)
ctxWithRequestInfo := icredentials.NewRequestInfoContext(ctx, ri)
authData, err := t.getTrAuthData(ctxWithRequestInfo, aud)
if err != nil {
return nil, err

View File

@ -1,5 +1,4 @@
/*
*
* Copyright 2021 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -13,10 +12,9 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package xds
package internal
import (
"google.golang.org/grpc/attributes"
@ -27,15 +25,15 @@ import (
// the Attributes field of resolver.Address.
type handshakeClusterNameKey struct{}
// SetHandshakeClusterName returns a copy of addr in which the Attributes field
// SetXDSHandshakeClusterName returns a copy of addr in which the Attributes field
// is updated with the cluster name.
func SetHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address {
func SetXDSHandshakeClusterName(addr resolver.Address, clusterName string) resolver.Address {
addr.Attributes = addr.Attributes.WithValues(handshakeClusterNameKey{}, clusterName)
return addr
}
// GetHandshakeClusterName returns cluster name stored in attr.
func GetHandshakeClusterName(attr *attributes.Attributes) (string, bool) {
// GetXDSHandshakeClusterName returns cluster name stored in attr.
func GetXDSHandshakeClusterName(attr *attributes.Attributes) (string, bool) {
v := attr.Value(handshakeClusterNameKey{})
name, ok := v.(string)
return name, ok

View File

@ -29,7 +29,7 @@ import (
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/xds/internal/client/load"
@ -414,7 +414,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) {
if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
cn, ok := xdsinternal.GetHandshakeClusterName(addrs1[0].Attributes)
cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes)
if !ok || cn != testClusterName {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName)
}
@ -455,7 +455,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
// New addresses should have the new cluster name.
cn2, ok := xdsinternal.GetHandshakeClusterName(addrs2[0].Attributes)
cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes)
if !ok || cn2 != testClusterName2 {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2)
}

View File

@ -29,8 +29,8 @@ import (
"sync"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/buffer"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/grpcsync"
"google.golang.org/grpc/resolver"
@ -327,7 +327,7 @@ func (cib *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balanc
clusterName := cib.getClusterName()
newAddrs := make([]resolver.Address, len(addrs))
for i, addr := range addrs {
newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName)
newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName)
}
return cib.ClientConn.NewSubConn(newAddrs, opts)
}
@ -336,7 +336,7 @@ func (cib *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []res
clusterName := cib.getClusterName()
newAddrs := make([]resolver.Address, len(addrs))
for i, addr := range addrs {
newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName)
newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName)
}
cib.ClientConn.UpdateAddresses(sc, newAddrs)
}

View File

@ -23,19 +23,18 @@ import (
"time"
"github.com/google/go-cmp/cmp"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/balancer/weightedroundrobin"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/grpclog"
"google.golang.org/grpc/internal/xds/env"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
"google.golang.org/grpc/xds/internal"
xdsi "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/balancer/balancergroup"
"google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator"
"google.golang.org/grpc/xds/internal/client"
@ -58,7 +57,7 @@ type localityConfig struct {
type balancerGroupWithConfig struct {
bg *balancergroup.BalancerGroup
stateAggregator *weightedaggregator.Aggregator
configs map[internal.LocalityID]*localityConfig
configs map[xdsi.LocalityID]*localityConfig
}
// edsBalancerImpl does load balancing based on the EDS responses. Note that it
@ -261,7 +260,7 @@ func (edsImpl *edsBalancerImpl) handleEDSResponse(edsResp xdsclient.EndpointsUpd
bgwc = &balancerGroupWithConfig{
bg: balancergroup.New(ccPriorityWrapper, edsImpl.buildOpts, stateAggregator, edsImpl.loadReporter, edsImpl.logger),
stateAggregator: stateAggregator,
configs: make(map[internal.LocalityID]*localityConfig),
configs: make(map[xdsi.LocalityID]*localityConfig),
}
edsImpl.priorityToLocalities[priority] = bgwc
priorityChanged = true
@ -295,7 +294,7 @@ func (edsImpl *edsBalancerImpl) handleEDSResponsePerPriority(bgwc *balancerGroup
// newLocalitiesSet contains all names of localities in the new EDS response
// for the same priority. It's used to delete localities that are removed in
// the new EDS response.
newLocalitiesSet := make(map[internal.LocalityID]struct{})
newLocalitiesSet := make(map[xdsi.LocalityID]struct{})
var rebuildStateAndPicker bool
for _, locality := range newLocalities {
// One balancer for each locality.
@ -498,7 +497,7 @@ func (ebwcc *edsBalancerWrapperCC) NewSubConn(addrs []resolver.Address, opts bal
clusterName := ebwcc.parent.getClusterName()
newAddrs := make([]resolver.Address, len(addrs))
for i, addr := range addrs {
newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName)
newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName)
}
return ebwcc.parent.newSubConn(ebwcc.priority, newAddrs, opts)
}
@ -507,7 +506,7 @@ func (ebwcc *edsBalancerWrapperCC) UpdateAddresses(sc balancer.SubConn, addrs []
clusterName := ebwcc.parent.getClusterName()
newAddrs := make([]resolver.Address, len(addrs))
for i, addr := range addrs {
newAddrs[i] = xdsinternal.SetHandshakeClusterName(addr, clusterName)
newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName)
}
ebwcc.ClientConn.UpdateAddresses(sc, newAddrs)
}

View File

@ -26,14 +26,14 @@ import (
corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
xdsinternal "google.golang.org/grpc/internal/credentials/xds"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/xds/env"
"google.golang.org/grpc/xds/internal"
xdsinternal "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/balancer/balancergroup"
"google.golang.org/grpc/xds/internal/client"
xdsclient "google.golang.org/grpc/xds/internal/client"
@ -834,7 +834,7 @@ func (s) TestEDS_LoadReport(t *testing.T) {
edsb.updateServiceRequestsConfig(testServiceName, &maxRequestsTemp)
defer client.ClearCounterForTesting(testServiceName)
backendToBalancerID := make(map[balancer.SubConn]internal.LocalityID)
backendToBalancerID := make(map[balancer.SubConn]xdsinternal.LocalityID)
const testDropCategory = "test-drop"
// Two localities, each with one backend.
@ -844,7 +844,7 @@ func (s) TestEDS_LoadReport(t *testing.T) {
sc1 := <-cc.NewSubConnCh
edsb.handleSubConnStateChange(sc1, connectivity.Connecting)
edsb.handleSubConnStateChange(sc1, connectivity.Ready)
locality1 := internal.LocalityID{SubZone: testSubZones[0]}
locality1 := xdsinternal.LocalityID{SubZone: testSubZones[0]}
backendToBalancerID[sc1] = locality1
// Add the second locality later to make sure sc2 belongs to the second
@ -855,7 +855,7 @@ func (s) TestEDS_LoadReport(t *testing.T) {
sc2 := <-cc.NewSubConnCh
edsb.handleSubConnStateChange(sc2, connectivity.Connecting)
edsb.handleSubConnStateChange(sc2, connectivity.Ready)
locality2 := internal.LocalityID{SubZone: testSubZones[1]}
locality2 := xdsinternal.LocalityID{SubZone: testSubZones[1]}
backendToBalancerID[sc2] = locality2
// Test roundrobin with two subconns.
@ -954,7 +954,7 @@ func (s) TestEDS_ClusterNameInAddressAttributes(t *testing.T) {
if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
cn, ok := xdsinternal.GetHandshakeClusterName(addrs1[0].Attributes)
cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes)
if !ok || cn != clusterName1 {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, clusterName1)
}
@ -986,7 +986,7 @@ func (s) TestEDS_ClusterNameInAddressAttributes(t *testing.T) {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
// New addresses should have the new cluster name.
cn2, ok := xdsinternal.GetHandshakeClusterName(addrs2[0].Attributes)
cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes)
if !ok || cn2 != clusterName2 {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, clusterName1)
}