credentials/alts: Add ServiceOption for server-side ALTS creation (#2009)

* Move handshaker_service_address flag to binaries
This commit is contained in:
Cesar Ghali 2018-04-23 11:11:20 -07:00 committed by GitHub
parent 4172bfc25e
commit 75d37eff66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 94 additions and 38 deletions

View File

@ -40,6 +40,9 @@ import (
) )
const ( const (
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
// handshaker service address in the hypervisor.
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
// defaultTimeout specifies the server handshake timeout. // defaultTimeout specifies the server handshake timeout.
defaultTimeout = 30.0 * time.Second defaultTimeout = 30.0 * time.Second
// The following constants specify the minimum and maximum acceptable // The following constants specify the minimum and maximum acceptable
@ -95,39 +98,71 @@ type ClientOptions struct {
// TargetServiceAccounts contains a list of expected target service // TargetServiceAccounts contains a list of expected target service
// accounts. // accounts.
TargetServiceAccounts []string TargetServiceAccounts []string
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultClientOptions creates a new ClientOptions object with the default
// values.
func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
}
// ServerOptions contains the server-side options of an ALTS channel. These
// options will be passed to the underlying ALTS handshaker.
type ServerOptions struct {
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
// address to connect to.
HandshakerServiceAddress string
}
// DefaultServerOptions creates a new ServerOptions object with the default
// values.
func DefaultServerOptions() *ServerOptions {
return &ServerOptions{
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
}
} }
// altsTC is the credentials required for authenticating a connection using ALTS. // altsTC is the credentials required for authenticating a connection using ALTS.
// It implements credentials.TransportCredentials interface. // It implements credentials.TransportCredentials interface.
type altsTC struct { type altsTC struct {
info *credentials.ProtocolInfo info *credentials.ProtocolInfo
hsAddr string hsAddr string
side core.Side side core.Side
accounts []string accounts []string
hsAddress string
} }
// NewClientCreds constructs a client-side ALTS TransportCredentials object. // NewClientCreds constructs a client-side ALTS TransportCredentials object.
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials { func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
return newALTS(core.ClientSide, opts.TargetServiceAccounts) return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
} }
// NewServerCreds constructs a server-side ALTS TransportCredentials object. // NewServerCreds constructs a server-side ALTS TransportCredentials object.
func NewServerCreds() credentials.TransportCredentials { func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
return newALTS(core.ServerSide, nil) return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
} }
func newALTS(side core.Side, accounts []string) credentials.TransportCredentials { func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
once.Do(func() { once.Do(func() {
vmOnGCP = isRunningOnGCP() vmOnGCP = isRunningOnGCP()
}) })
if hsAddress == "" {
hsAddress = hypervisorHandshakerServiceAddress
}
return &altsTC{ return &altsTC{
info: &credentials.ProtocolInfo{ info: &credentials.ProtocolInfo{
SecurityProtocol: "alts", SecurityProtocol: "alts",
SecurityVersion: "1.0", SecurityVersion: "1.0",
}, },
side: side, side: side,
accounts: accounts, accounts: accounts,
hsAddress: hsAddress,
} }
} }
@ -138,7 +173,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
} }
// Connecting to ALTS handshaker service. // Connecting to ALTS handshaker service.
hsConn, err := service.Dial() hsConn, err := service.Dial(g.hsAddress)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -191,7 +226,7 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
return nil, nil, ErrUntrustedPlatform return nil, nil, ErrUntrustedPlatform
} }
// Connecting to ALTS handshaker service. // Connecting to ALTS handshaker service.
hsConn, err := service.Dial() hsConn, err := service.Dial(g.hsAddress)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -28,7 +28,7 @@ import (
func TestInfoServerName(t *testing.T) { func TestInfoServerName(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only // This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds. // use NewServerCreds and not NewClientCreds.
alts := NewServerCreds() alts := NewServerCreds(DefaultServerOptions())
if got, want := alts.Info().ServerName, ""; got != want { if got, want := alts.Info().ServerName, ""; got != want {
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want) t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
} }
@ -38,7 +38,7 @@ func TestOverrideServerName(t *testing.T) {
wantServerName := "server.name" wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only // This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds. // use NewServerCreds and not NewClientCreds.
c := NewServerCreds() c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName) c.OverrideServerName(wantServerName)
if got, want := c.Info().ServerName, wantServerName; got != want { if got, want := c.Info().ServerName, wantServerName; got != want {
t.Fatalf("c.Info().ServerName = %v, want %v", got, want) t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
@ -49,7 +49,7 @@ func TestClone(t *testing.T) {
wantServerName := "server.name" wantServerName := "server.name"
// This is not testing any handshaker functionality, so it's fine to only // This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds. // use NewServerCreds and not NewClientCreds.
c := NewServerCreds() c := NewServerCreds(DefaultServerOptions())
c.OverrideServerName(wantServerName) c.OverrideServerName(wantServerName)
cc := c.Clone() cc := c.Clone()
if got, want := cc.Info().ServerName, wantServerName; got != want { if got, want := cc.Info().ServerName, wantServerName; got != want {
@ -67,7 +67,7 @@ func TestClone(t *testing.T) {
func TestInfo(t *testing.T) { func TestInfo(t *testing.T) {
// This is not testing any handshaker functionality, so it's fine to only // This is not testing any handshaker functionality, so it's fine to only
// use NewServerCreds and not NewClientCreds. // use NewServerCreds and not NewClientCreds.
c := NewServerCreds() c := NewServerCreds(DefaultServerOptions())
info := c.Info() info := c.Info()
if got, want := info.ProtocolVersion, ""; got != want { if got, want := info.ProtocolVersion, ""; got != want {
t.Errorf("info.ProtocolVersion=%v, want %v", got, want) t.Errorf("info.ProtocolVersion=%v, want %v", got, want)

View File

@ -21,16 +21,12 @@
package service package service
import ( import (
"flag"
"sync" "sync"
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
) )
var ( var (
// hsServiceAddr specifies the default ALTS handshaker service address in
// the hypervisor.
hsServiceAddr = flag.String("handshaker_service_address", "metadata.google.internal:8080", "ALTS handshaker gRPC service address")
// hsConn represents a connection to hypervisor handshaker service. // hsConn represents a connection to hypervisor handshaker service.
hsConn *grpc.ClientConn hsConn *grpc.ClientConn
mu sync.Mutex mu sync.Mutex
@ -42,8 +38,8 @@ type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, erro
// Dial dials the handshake service in the hypervisor. If a connection has // Dial dials the handshake service in the hypervisor. If a connection has
// already been established, this function returns it. Otherwise, a new // already been established, this function returns it. Otherwise, a new
// connection is created, // connection is created.
func Dial() (*grpc.ClientConn, error) { func Dial(hsAddress string) (*grpc.ClientConn, error) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -51,7 +47,7 @@ func Dial() (*grpc.ClientConn, error) {
// Create a new connection to the handshaker service. Note that // Create a new connection to the handshaker service. Note that
// this connection stays open until the application is closed. // this connection stays open until the application is closed.
var err error var err error
hsConn, err = hsDialer(*hsServiceAddr, grpc.WithInsecure()) hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -24,6 +24,11 @@ import (
grpc "google.golang.org/grpc" grpc "google.golang.org/grpc"
) )
const (
// The address is irrelevant in this test.
testAddress = "some_address"
)
func TestDial(t *testing.T) { func TestDial(t *testing.T) {
defer func() func() { defer func() func() {
temp := hsDialer temp := hsDialer
@ -39,24 +44,24 @@ func TestDial(t *testing.T) {
hsConn = nil hsConn = nil
// First call to Dial, it should create set hsConn. // First call to Dial, it should create set hsConn.
conn1, err := Dial() conn1, err := Dial(testAddress)
if err != nil { if err != nil {
t.Fatalf("first call to Dial failed: %v", err) t.Fatalf("first call to Dial failed: %v", err)
} }
if conn1 == nil { if conn1 == nil {
t.Fatal("first call to Dial()=(nil, _), want not nil") t.Fatal("first call to Dial(_)=(nil, _), want not nil")
} }
if got, want := hsConn, conn1; got != want { if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want) t.Fatalf("hsConn=%v, want %v", got, want)
} }
// Second call to Dial() should return conn1 above. // Second call to Dial should return conn1 above.
conn2, err := Dial() conn2, err := Dial(testAddress)
if err != nil { if err != nil {
t.Fatalf("second call to Dial() failed: %v", err) t.Fatalf("second call to Dial(_) failed: %v", err)
} }
if got, want := conn2, conn1; got != want { if got, want := conn2, conn1; got != want {
t.Fatalf("second call to Dial()=(%v, _), want (%v,. _)", got, want) t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
} }
if got, want := hsConn, conn1; got != want { if got, want := hsConn, conn1; got != want {
t.Fatalf("hsConn=%v, want %v", got, want) t.Fatalf("hsConn=%v, want %v", got, want)

View File

@ -35,13 +35,18 @@ const (
) )
var ( var (
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening") serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
) )
func main() { func main() {
flag.Parse() flag.Parse()
altsTC := alts.NewClientCreds(&alts.ClientOptions{}) opts := alts.DefaultClientOptions()
if *hsAddr != "" {
opts.HandshakerServiceAddress = *hsAddr
}
altsTC := alts.NewClientCreds(opts)
// Block until the server is ready. // Block until the server is ready.
conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock()) conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock())
if err != nil { if err != nil {

View File

@ -31,6 +31,7 @@ import (
) )
var ( var (
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening") serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
) )
@ -41,7 +42,11 @@ func main() {
if err != nil { if err != nil {
grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err) grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err)
} }
altsTC := alts.NewServerCreds() opts := alts.DefaultServerOptions()
if *hsAddr != "" {
opts.HandshakerServiceAddress = *hsAddr
}
altsTC := alts.NewServerCreds(opts)
grpcServer := grpc.NewServer(grpc.Creds(altsTC)) grpcServer := grpc.NewServer(grpc.Creds(altsTC))
testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer()) testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer())
grpcServer.Serve(lis) grpcServer.Serve(lis)

View File

@ -37,6 +37,7 @@ var (
caFile = flag.String("ca_file", "", "The file containning the CA root cert file") caFile = flag.String("ca_file", "", "The file containning the CA root cert file")
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true") useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true")
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root") testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file")
oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens")
@ -110,7 +111,11 @@ func main() {
opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope)))) opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope))))
} }
} else if *useALTS { } else if *useALTS {
altsTC := alts.NewClientCreds(&alts.ClientOptions{}) altsOpts := alts.DefaultClientOptions()
if *altsHSAddr != "" {
altsOpts.HandshakerServiceAddress = *altsHSAddr
}
altsTC := alts.NewClientCreds(altsOpts)
opts = append(opts, grpc.WithTransportCredentials(altsTC)) opts = append(opts, grpc.WithTransportCredentials(altsTC))
} else { } else {
opts = append(opts, grpc.WithInsecure()) opts = append(opts, grpc.WithInsecure())

View File

@ -33,11 +33,12 @@ import (
) )
var ( var (
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)") useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
certFile = flag.String("tls_cert_file", "", "The TLS cert file") altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
keyFile = flag.String("tls_key_file", "", "The TLS key file") certFile = flag.String("tls_cert_file", "", "The TLS cert file")
port = flag.Int("port", 10000, "The server port") keyFile = flag.String("tls_key_file", "", "The TLS key file")
port = flag.Int("port", 10000, "The server port")
) )
func main() { func main() {
@ -64,7 +65,11 @@ func main() {
} }
opts = append(opts, grpc.Creds(creds)) opts = append(opts, grpc.Creds(creds))
} else if *useALTS { } else if *useALTS {
altsTC := alts.NewServerCreds() altsOpts := alts.DefaultServerOptions()
if *altsHSAddr != "" {
altsOpts.HandshakerServiceAddress = *altsHSAddr
}
altsTC := alts.NewServerCreds(altsOpts)
opts = append(opts, grpc.Creds(altsTC)) opts = append(opts, grpc.Creds(altsTC))
} }
server := grpc.NewServer(opts...) server := grpc.NewServer(opts...)