mirror of https://github.com/grpc/grpc-go.git
269 lines
7.5 KiB
Go
269 lines
7.5 KiB
Go
/*
 *
 * Copyright 2021 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
|
|
|
|
package google
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"google.golang.org/grpc/credentials"
|
|
icredentials "google.golang.org/grpc/internal/credentials"
|
|
"google.golang.org/grpc/internal/grpctest"
|
|
"google.golang.org/grpc/internal/xds"
|
|
"google.golang.org/grpc/resolver"
|
|
)
|
|
|
|
// defaultTestTimeout is the deadline applied to the contexts that the tests in
// this file create with context.WithTimeout.
var defaultTestTimeout = 10 * time.Second
|
|
|
|
type s struct {
|
|
grpctest.Tester
|
|
}
|
|
|
|
func Test(t *testing.T) {
|
|
grpctest.RunSubTests(t, s{})
|
|
}
|
|
|
|
type testCreds struct {
|
|
credentials.TransportCredentials
|
|
typ string
|
|
}
|
|
|
|
func (c *testCreds) ClientHandshake(context.Context, string, net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
|
return nil, &testAuthInfo{typ: c.typ}, nil
|
|
}
|
|
|
|
func (c *testCreds) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
|
return nil, &testAuthInfo{typ: c.typ}, nil
|
|
}
|
|
|
|
// testAuthInfo is a minimal auth-info value that carries nothing but the name
// of the auth type it represents.
type testAuthInfo struct {
	// typ is the string handed back by AuthType.
	typ string
}

// AuthType returns the auth type name stored in the receiver.
func (a *testAuthInfo) AuthType() string {
	return a.typ
}
|
|
|
|
// testPerRPCCreds is a fake per-RPC credential that returns a fixed metadata
// map for every request.
type testPerRPCCreds struct {
	// md is returned as-is (not copied) by GetRequestMetadata.
	md map[string]string
}

// RequireTransportSecurity always reports true.
func (*testPerRPCCreds) RequireTransportSecurity() bool {
	return true
}

// GetRequestMetadata ignores the context and the URIs and returns the
// configured metadata map with a nil error.
func (p *testPerRPCCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
	return p.md, nil
}
|
|
|
|
var (
|
|
testTLS = &testCreds{typ: "tls"}
|
|
testALTS = &testCreds{typ: "alts"}
|
|
)
|
|
|
|
func overrideNewCredsFuncs() func() {
|
|
origNewTLS := newTLS
|
|
newTLS = func() credentials.TransportCredentials {
|
|
return testTLS
|
|
}
|
|
origNewALTS := newALTS
|
|
newALTS = func() credentials.TransportCredentials {
|
|
return testALTS
|
|
}
|
|
origNewADC := newADC
|
|
newADC = func(context.Context) (credentials.PerRPCCredentials, error) {
|
|
// We do not use perRPC creds in this test. It is safe to return nil here.
|
|
return nil, nil
|
|
}
|
|
|
|
return func() {
|
|
newTLS = origNewTLS
|
|
newALTS = origNewALTS
|
|
newADC = origNewADC
|
|
}
|
|
}
|
|
|
|
// TestClientHandshakeBasedOnClusterName that by default (without switching
|
|
// modes), ClientHandshake does either tls or alts base on the cluster name in
|
|
// attributes.
|
|
func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
defer overrideNewCredsFuncs()()
|
|
for bundleTyp, tc := range map[string]credentials.Bundle{
|
|
"defaultCredsWithOptions": NewDefaultCredentialsWithOptions(DefaultCredentialsOptions{}),
|
|
"defaultCreds": NewDefaultCredentials(),
|
|
"computeCreds": NewComputeEngineCredentials(),
|
|
} {
|
|
tests := []struct {
|
|
name string
|
|
ctx context.Context
|
|
wantTyp string
|
|
}{
|
|
{
|
|
name: "no cluster name",
|
|
ctx: ctx,
|
|
wantTyp: "tls",
|
|
},
|
|
{
|
|
name: "with non-CFE cluster name",
|
|
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
|
|
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
|
|
}),
|
|
// non-CFE backends should use alts.
|
|
wantTyp: "alts",
|
|
},
|
|
{
|
|
name: "with CFE cluster name",
|
|
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
|
|
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
|
|
}),
|
|
// CFE should use tls.
|
|
wantTyp: "tls",
|
|
},
|
|
{
|
|
name: "with xdstp CFE cluster name",
|
|
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
|
|
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
|
|
}),
|
|
// CFE should use tls.
|
|
wantTyp: "tls",
|
|
},
|
|
{
|
|
name: "with xdstp non-CFE cluster name",
|
|
ctx: icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{
|
|
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
|
|
}),
|
|
// non-CFE should use atls.
|
|
wantTyp: "alts",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(bundleTyp+" "+tt.name, func(t *testing.T) {
|
|
_, info, err := tc.TransportCredentials().ClientHandshake(tt.ctx, "", nil)
|
|
if err != nil {
|
|
t.Fatalf("ClientHandshake failed: %v", err)
|
|
}
|
|
if gotType := info.AuthType(); gotType != tt.wantTyp {
|
|
t.Fatalf("unexpected authtype: %v, want: %v", gotType, tt.wantTyp)
|
|
}
|
|
|
|
_, infoServer, err := tc.TransportCredentials().ServerHandshake(nil)
|
|
if err != nil {
|
|
t.Fatalf("ClientHandshake failed: %v", err)
|
|
}
|
|
// ServerHandshake should always do TLS.
|
|
if gotType := infoServer.AuthType(); gotType != "tls" {
|
|
t.Fatalf("unexpected server authtype: %v, want: %v", gotType, "tls")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDefaultCredentialsWithOptions(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
|
defer cancel()
|
|
md1 := map[string]string{"foo": "tls"}
|
|
md2 := map[string]string{"foo": "alts"}
|
|
tests := []struct {
|
|
desc string
|
|
defaultCredsOpts DefaultCredentialsOptions
|
|
authInfo credentials.AuthInfo
|
|
wantedMetadata map[string]string
|
|
}{
|
|
{
|
|
desc: "no ALTSPerRPCCreds with tls channel",
|
|
defaultCredsOpts: DefaultCredentialsOptions{
|
|
PerRPCCreds: &testPerRPCCreds{
|
|
md: md1,
|
|
},
|
|
},
|
|
authInfo: &testAuthInfo{typ: "tls"},
|
|
wantedMetadata: md1,
|
|
},
|
|
{
|
|
desc: "no ALTSPerRPCCreds with alts channel",
|
|
defaultCredsOpts: DefaultCredentialsOptions{
|
|
PerRPCCreds: &testPerRPCCreds{
|
|
md: md1,
|
|
},
|
|
},
|
|
authInfo: &testAuthInfo{typ: "alts"},
|
|
wantedMetadata: md1,
|
|
},
|
|
{
|
|
desc: "ALTSPerRPCCreds specified with tls channel",
|
|
defaultCredsOpts: DefaultCredentialsOptions{
|
|
PerRPCCreds: &testPerRPCCreds{
|
|
md: md1,
|
|
},
|
|
ALTSPerRPCCreds: &testPerRPCCreds{
|
|
md: md2,
|
|
},
|
|
},
|
|
authInfo: &testAuthInfo{typ: "tls"},
|
|
wantedMetadata: md1,
|
|
},
|
|
{
|
|
desc: "ALTSPerRPCCreds specified with alts channel",
|
|
defaultCredsOpts: DefaultCredentialsOptions{
|
|
PerRPCCreds: &testPerRPCCreds{
|
|
md: md1,
|
|
},
|
|
ALTSPerRPCCreds: &testPerRPCCreds{
|
|
md: md2,
|
|
},
|
|
},
|
|
authInfo: &testAuthInfo{typ: "alts"},
|
|
wantedMetadata: md2,
|
|
},
|
|
{
|
|
desc: "ALTSPerRPCCreds specified with unknown channel",
|
|
defaultCredsOpts: DefaultCredentialsOptions{
|
|
PerRPCCreds: &testPerRPCCreds{
|
|
md: md1,
|
|
},
|
|
ALTSPerRPCCreds: &testPerRPCCreds{
|
|
md: md2,
|
|
},
|
|
},
|
|
authInfo: &testAuthInfo{typ: "foo"},
|
|
wantedMetadata: md1,
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.desc, func(t *testing.T) {
|
|
bundle := NewDefaultCredentialsWithOptions(tc.defaultCredsOpts)
|
|
ri := credentials.RequestInfo{AuthInfo: tc.authInfo}
|
|
ctx := credentials.NewContextWithRequestInfo(ctx, ri)
|
|
got, err := bundle.PerRPCCredentials().GetRequestMetadata(ctx, "uri")
|
|
if err != nil {
|
|
t.Fatalf("Bundle's PerRPCCredentials().GetRequestMetadata() unexpected error = %v", err)
|
|
}
|
|
if diff := cmp.Diff(got, tc.wantedMetadata); diff != "" {
|
|
t.Errorf("Unexpected request metadata from bundle's PerRPCCredentials. Diff (-got +want):\n%v", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|