mirror of https://github.com/grpc/grpc-go.git
credentials/alts: Add ServiceOption for server-side ALTS creation (#2009)
* Move handshaker_service_address flag to binaries
This commit is contained in:
parent
4172bfc25e
commit
75d37eff66
|
@ -40,6 +40,9 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
|
||||
// handshaker service address in the hypervisor.
|
||||
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
|
||||
// defaultTimeout specifies the server handshake timeout.
|
||||
defaultTimeout = 30.0 * time.Second
|
||||
// The following constants specify the minimum and maximum acceptable
|
||||
|
@ -95,39 +98,71 @@ type ClientOptions struct {
|
|||
// TargetServiceAccounts contains a list of expected target service
|
||||
// accounts.
|
||||
TargetServiceAccounts []string
|
||||
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
|
||||
// address to connect to.
|
||||
HandshakerServiceAddress string
|
||||
}
|
||||
|
||||
// DefaultClientOptions creates a new ClientOptions object with the default
|
||||
// values.
|
||||
func DefaultClientOptions() *ClientOptions {
|
||||
return &ClientOptions{
|
||||
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||||
}
|
||||
}
|
||||
|
||||
// ServerOptions contains the server-side options of an ALTS channel. These
|
||||
// options will be passed to the underlying ALTS handshaker.
|
||||
type ServerOptions struct {
|
||||
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
|
||||
// address to connect to.
|
||||
HandshakerServiceAddress string
|
||||
}
|
||||
|
||||
// DefaultServerOptions creates a new ServerOptions object with the default
|
||||
// values.
|
||||
func DefaultServerOptions() *ServerOptions {
|
||||
return &ServerOptions{
|
||||
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||||
}
|
||||
}
|
||||
|
||||
// altsTC is the credentials required for authenticating a connection using ALTS.
|
||||
// It implements credentials.TransportCredentials interface.
|
||||
type altsTC struct {
|
||||
info *credentials.ProtocolInfo
|
||||
hsAddr string
|
||||
side core.Side
|
||||
accounts []string
|
||||
info *credentials.ProtocolInfo
|
||||
hsAddr string
|
||||
side core.Side
|
||||
accounts []string
|
||||
hsAddress string
|
||||
}
|
||||
|
||||
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
||||
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
|
||||
return newALTS(core.ClientSide, opts.TargetServiceAccounts)
|
||||
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
|
||||
}
|
||||
|
||||
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
|
||||
func NewServerCreds() credentials.TransportCredentials {
|
||||
return newALTS(core.ServerSide, nil)
|
||||
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
|
||||
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
|
||||
}
|
||||
|
||||
func newALTS(side core.Side, accounts []string) credentials.TransportCredentials {
|
||||
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
|
||||
once.Do(func() {
|
||||
vmOnGCP = isRunningOnGCP()
|
||||
})
|
||||
|
||||
if hsAddress == "" {
|
||||
hsAddress = hypervisorHandshakerServiceAddress
|
||||
}
|
||||
return &altsTC{
|
||||
info: &credentials.ProtocolInfo{
|
||||
SecurityProtocol: "alts",
|
||||
SecurityVersion: "1.0",
|
||||
},
|
||||
side: side,
|
||||
accounts: accounts,
|
||||
side: side,
|
||||
accounts: accounts,
|
||||
hsAddress: hsAddress,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,7 +173,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
|
|||
}
|
||||
|
||||
// Connecting to ALTS handshaker service.
|
||||
hsConn, err := service.Dial()
|
||||
hsConn, err := service.Dial(g.hsAddress)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -191,7 +226,7 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
|
|||
return nil, nil, ErrUntrustedPlatform
|
||||
}
|
||||
// Connecting to ALTS handshaker service.
|
||||
hsConn, err := service.Dial()
|
||||
hsConn, err := service.Dial(g.hsAddress)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
|
@ -28,7 +28,7 @@ import (
|
|||
func TestInfoServerName(t *testing.T) {
|
||||
// This is not testing any handshaker functionality, so it's fine to only
|
||||
// use NewServerCreds and not NewClientCreds.
|
||||
alts := NewServerCreds()
|
||||
alts := NewServerCreds(DefaultServerOptions())
|
||||
if got, want := alts.Info().ServerName, ""; got != want {
|
||||
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ func TestOverrideServerName(t *testing.T) {
|
|||
wantServerName := "server.name"
|
||||
// This is not testing any handshaker functionality, so it's fine to only
|
||||
// use NewServerCreds and not NewClientCreds.
|
||||
c := NewServerCreds()
|
||||
c := NewServerCreds(DefaultServerOptions())
|
||||
c.OverrideServerName(wantServerName)
|
||||
if got, want := c.Info().ServerName, wantServerName; got != want {
|
||||
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
|
||||
|
@ -49,7 +49,7 @@ func TestClone(t *testing.T) {
|
|||
wantServerName := "server.name"
|
||||
// This is not testing any handshaker functionality, so it's fine to only
|
||||
// use NewServerCreds and not NewClientCreds.
|
||||
c := NewServerCreds()
|
||||
c := NewServerCreds(DefaultServerOptions())
|
||||
c.OverrideServerName(wantServerName)
|
||||
cc := c.Clone()
|
||||
if got, want := cc.Info().ServerName, wantServerName; got != want {
|
||||
|
@ -67,7 +67,7 @@ func TestClone(t *testing.T) {
|
|||
func TestInfo(t *testing.T) {
|
||||
// This is not testing any handshaker functionality, so it's fine to only
|
||||
// use NewServerCreds and not NewClientCreds.
|
||||
c := NewServerCreds()
|
||||
c := NewServerCreds(DefaultServerOptions())
|
||||
info := c.Info()
|
||||
if got, want := info.ProtocolVersion, ""; got != want {
|
||||
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
|
||||
|
|
|
@ -21,16 +21,12 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"sync"
|
||||
|
||||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
var (
|
||||
// hsServiceAddr specifies the default ALTS handshaker service address in
|
||||
// the hypervisor.
|
||||
hsServiceAddr = flag.String("handshaker_service_address", "metadata.google.internal:8080", "ALTS handshaker gRPC service address")
|
||||
// hsConn represents a connection to hypervisor handshaker service.
|
||||
hsConn *grpc.ClientConn
|
||||
mu sync.Mutex
|
||||
|
@ -42,8 +38,8 @@ type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, erro
|
|||
|
||||
// Dial dials the handshake service in the hypervisor. If a connection has
|
||||
// already been established, this function returns it. Otherwise, a new
|
||||
// connection is created,
|
||||
func Dial() (*grpc.ClientConn, error) {
|
||||
// connection is created.
|
||||
func Dial(hsAddress string) (*grpc.ClientConn, error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
|
@ -51,7 +47,7 @@ func Dial() (*grpc.ClientConn, error) {
|
|||
// Create a new connection to the handshaker service. Note that
|
||||
// this connection stays open until the application is closed.
|
||||
var err error
|
||||
hsConn, err = hsDialer(*hsServiceAddr, grpc.WithInsecure())
|
||||
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -24,6 +24,11 @@ import (
|
|||
grpc "google.golang.org/grpc"
|
||||
)
|
||||
|
||||
const (
|
||||
// The address is irrelevant in this test.
|
||||
testAddress = "some_address"
|
||||
)
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
defer func() func() {
|
||||
temp := hsDialer
|
||||
|
@ -39,24 +44,24 @@ func TestDial(t *testing.T) {
|
|||
hsConn = nil
|
||||
|
||||
// First call to Dial, it should create set hsConn.
|
||||
conn1, err := Dial()
|
||||
conn1, err := Dial(testAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("first call to Dial failed: %v", err)
|
||||
}
|
||||
if conn1 == nil {
|
||||
t.Fatal("first call to Dial()=(nil, _), want not nil")
|
||||
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
|
||||
}
|
||||
if got, want := hsConn, conn1; got != want {
|
||||
t.Fatalf("hsConn=%v, want %v", got, want)
|
||||
}
|
||||
|
||||
// Second call to Dial() should return conn1 above.
|
||||
conn2, err := Dial()
|
||||
// Second call to Dial should return conn1 above.
|
||||
conn2, err := Dial(testAddress)
|
||||
if err != nil {
|
||||
t.Fatalf("second call to Dial() failed: %v", err)
|
||||
t.Fatalf("second call to Dial(_) failed: %v", err)
|
||||
}
|
||||
if got, want := conn2, conn1; got != want {
|
||||
t.Fatalf("second call to Dial()=(%v, _), want (%v,. _)", got, want)
|
||||
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
|
||||
}
|
||||
if got, want := hsConn, conn1; got != want {
|
||||
t.Fatalf("hsConn=%v, want %v", got, want)
|
||||
|
|
|
@ -35,13 +35,18 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
|
||||
opts := alts.DefaultClientOptions()
|
||||
if *hsAddr != "" {
|
||||
opts.HandshakerServiceAddress = *hsAddr
|
||||
}
|
||||
altsTC := alts.NewClientCreds(opts)
|
||||
// Block until the server is ready.
|
||||
conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock())
|
||||
if err != nil {
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
||||
)
|
||||
|
||||
|
@ -41,7 +42,11 @@ func main() {
|
|||
if err != nil {
|
||||
grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err)
|
||||
}
|
||||
altsTC := alts.NewServerCreds()
|
||||
opts := alts.DefaultServerOptions()
|
||||
if *hsAddr != "" {
|
||||
opts.HandshakerServiceAddress = *hsAddr
|
||||
}
|
||||
altsTC := alts.NewServerCreds(opts)
|
||||
grpcServer := grpc.NewServer(grpc.Creds(altsTC))
|
||||
testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer())
|
||||
grpcServer.Serve(lis)
|
||||
|
|
|
@ -37,6 +37,7 @@ var (
|
|||
caFile = flag.String("ca_file", "", "The file containning the CA root cert file")
|
||||
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true")
|
||||
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
||||
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
|
||||
serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file")
|
||||
oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens")
|
||||
|
@ -110,7 +111,11 @@ func main() {
|
|||
opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope))))
|
||||
}
|
||||
} else if *useALTS {
|
||||
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
|
||||
altsOpts := alts.DefaultClientOptions()
|
||||
if *altsHSAddr != "" {
|
||||
altsOpts.HandshakerServiceAddress = *altsHSAddr
|
||||
}
|
||||
altsTC := alts.NewClientCreds(altsOpts)
|
||||
opts = append(opts, grpc.WithTransportCredentials(altsTC))
|
||||
} else {
|
||||
opts = append(opts, grpc.WithInsecure())
|
||||
|
|
|
@ -33,11 +33,12 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
|
||||
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
||||
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
|
||||
keyFile = flag.String("tls_key_file", "", "The TLS key file")
|
||||
port = flag.Int("port", 10000, "The server port")
|
||||
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
|
||||
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
||||
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
|
||||
keyFile = flag.String("tls_key_file", "", "The TLS key file")
|
||||
port = flag.Int("port", 10000, "The server port")
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@ -64,7 +65,11 @@ func main() {
|
|||
}
|
||||
opts = append(opts, grpc.Creds(creds))
|
||||
} else if *useALTS {
|
||||
altsTC := alts.NewServerCreds()
|
||||
altsOpts := alts.DefaultServerOptions()
|
||||
if *altsHSAddr != "" {
|
||||
altsOpts.HandshakerServiceAddress = *altsHSAddr
|
||||
}
|
||||
altsTC := alts.NewServerCreds(altsOpts)
|
||||
opts = append(opts, grpc.Creds(altsTC))
|
||||
}
|
||||
server := grpc.NewServer(opts...)
|
||||
|
|
Loading…
Reference in New Issue