credentials/alts: Add ServiceOption for server-side ALTS creation (#2009)

* Move handshaker_service_address flag to binaries
This commit is contained in:
Cesar Ghali 2018-04-23 11:11:20 -07:00 committed by GitHub
parent 4172bfc25e
commit 75d37eff66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 94 additions and 38 deletions

View File

@ -40,6 +40,9 @@ import (
)
const (
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
// handshaker service address in the hypervisor.
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
// defaultTimeout specifies the server handshake timeout.
defaultTimeout = 30.0 * time.Second
// The following constants specify the minimum and maximum acceptable
@ -95,39 +98,71 @@ type ClientOptions struct {
// TargetServiceAccounts contains a list of expected target service
// accounts.
TargetServiceAccounts []string
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultClientOptions creates a new ClientOptions object with the default
// values.
func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultServerOptions creates a new ServerOptions object with the default
// values.
func DefaultServerOptions() *ServerOptions {
return &ServerOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface.
type altsTC struct {
info *credentials.ProtocolInfo
hsAddr string
side core.Side
accounts []string
info *credentials.ProtocolInfo
hsAddr string
side core.Side
accounts []string
hsAddress string
}
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts)
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
}
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds() credentials.TransportCredentials {
return newALTS(core.ServerSide, nil)
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
}
func newALTS(side core.Side, accounts []string) credentials.TransportCredentials {
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
once.Do(func() {
vmOnGCP = isRunningOnGCP()
})
if hsAddress == "" {
hsAddress = hypervisorHandshakerServiceAddress
}
return &altsTC{
info: &credentials.ProtocolInfo{
SecurityProtocol: "alts",
SecurityVersion: "1.0",
},
side: side,
accounts: accounts,
side: side,
accounts: accounts,
hsAddress: hsAddress,
}
}
@ -138,7 +173,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial()
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}
@ -191,7 +226,7 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
return nil, nil, ErrUntrustedPlatform
}
// Connecting to ALTS handshaker service.
hsConn, err := service.Dial()
hsConn, err := service.Dial(g.hsAddress)
if err != nil {
return nil, nil, err
}

View File

@ -28,7 +28,7 @@ import (
func TestInfoServerName(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
alts := NewServerCreds()
alts := NewServerCreds(DefaultServerOptions())
if got, want := alts.Info().ServerName, ""; got != want {
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
}
@ -38,7 +38,7 @@ func TestOverrideServerName(t *testing.T) {
wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds()
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
@ -49,7 +49,7 @@ func TestClone(t *testing.T) {
wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds()
c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName)
cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want {
@ -67,7 +67,7 @@ func TestClone(t *testing.T) {
func TestInfo(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds.
c := NewServerCreds()
c := NewServerCreds(DefaultServerOptions())
info := c.Info()
if got, want := info.ProtocolVersion, ""; got != want {
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)

View File

@ -21,16 +21,12 @@
package service
import (
"flag"
"sync"
grpc "google.golang.org/grpc"
)
var (
// hsServiceAddr specifies the default ALTS handshaker service address in
// the hypervisor.
hsServiceAddr = flag.String("handshaker_service_address", "metadata.google.internal:8080", "ALTS handshaker gRPC service address")
// hsConn represents a connection to hypervisor handshaker service.
hsConn *grpc.ClientConn
mu sync.Mutex
@ -42,8 +38,8 @@ type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, erro
// Dial dials the handshake service in the hypervisor. If a connection has
// already been established, this function returns it. Otherwise, a new
// connection is created,
func Dial() (*grpc.ClientConn, error) {
// connection is created.
func Dial(hsAddress string) (*grpc.ClientConn, error) {
mu.Lock()
defer mu.Unlock()
@ -51,7 +47,7 @@ func Dial() (*grpc.ClientConn, error) {
// Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed.
var err error
hsConn, err = hsDialer(*hsServiceAddr, grpc.WithInsecure())
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
if err != nil {
return nil, err
}

View File

@ -24,6 +24,11 @@ import (
grpc "google.golang.org/grpc"
)
const (
// The address is irrelevant in this test.
testAddress = "some_address"
)
func TestDial(t *testing.T) {
defer func() func() {
temp := hsDialer
@ -39,24 +44,24 @@ func TestDial(t *testing.T) {
hsConn = nil
// First call to Dial, it should create set hsConn.
conn1, err := Dial()
conn1, err := Dial(testAddress)
if err != nil {
t.Fatalf("first call to Dial failed: %v", err)
}
if conn1 == nil {
t.Fatal("first call to Dial()=(nil, _), want not nil")
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)
}
// Second call to Dial() should return conn1 above.
conn2, err := Dial()
// Second call to Dial should return conn1 above.
conn2, err := Dial(testAddress)
if err != nil {
t.Fatalf("second call to Dial() failed: %v", err)
t.Fatalf("second call to Dial(_) failed: %v", err)
}
if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial()=(%v, _), want (%v,. _)", got, want)
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
}
if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want)

View File

@ -35,13 +35,18 @@ const (
)
var (
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
)
func main() {
flag.Parse()
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
opts := alts.DefaultClientOptions()
if *hsAddr != "" {
opts.HandshakerServiceAddress = *hsAddr
}
altsTC := alts.NewClientCreds(opts)
// Block until the server is ready.
conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock())
if err != nil {

View File

@ -31,6 +31,7 @@ import (
)
var (
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
)
@ -41,7 +42,11 @@ func main() {
if err != nil {
grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err)
}
altsTC := alts.NewServerCreds()
opts := alts.DefaultServerOptions()
if *hsAddr != "" {
opts.HandshakerServiceAddress = *hsAddr
}
altsTC := alts.NewServerCreds(opts)
grpcServer := grpc.NewServer(grpc.Creds(altsTC))
testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer())
grpcServer.Serve(lis)

View File

@ -37,6 +37,7 @@ var (
caFile = flag.String("ca_file", "", "The file containning the CA root cert file")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true")
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file")
oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens")
@ -110,7 +111,11 @@ func main() {
opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope))))
}
} else if *useALTS {
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
altsOpts := alts.DefaultClientOptions()
if *altsHSAddr != "" {
altsOpts.HandshakerServiceAddress = *altsHSAddr
}
altsTC := alts.NewClientCreds(altsOpts)
opts = append(opts, grpc.WithTransportCredentials(altsTC))
} else {
opts = append(opts, grpc.WithInsecure())

View File

@ -33,11 +33,12 @@ import (
)
var (
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
keyFile = flag.String("tls_key_file", "", "The TLS key file")
port = flag.Int("port", 10000, "The server port")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
keyFile = flag.String("tls_key_file", "", "The TLS key file")
port = flag.Int("port", 10000, "The server port")
)
func main() {
@ -64,7 +65,11 @@ func main() {
}
opts = append(opts, grpc.Creds(creds))
} else if *useALTS {
altsTC := alts.NewServerCreds()
altsOpts := alts.DefaultServerOptions()
if *altsHSAddr != "" {
altsOpts.HandshakerServiceAddress = *altsHSAddr
}
altsTC := alts.NewServerCreds(altsOpts)
opts = append(opts, grpc.Creds(altsTC))
}
server := grpc.NewServer(opts...)