mirror of https://github.com/grpc/grpc-go.git
credentials/alts: Add ServiceOption for server-side ALTS creation (#2009)
* Move handshaker_service_address flag to binaries
This commit is contained in:
parent
4172bfc25e
commit
75d37eff66
|
@ -40,6 +40,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// hypervisorHandshakerServiceAddress represents the default ALTS gRPC
|
||||||
|
// handshaker service address in the hypervisor.
|
||||||
|
hypervisorHandshakerServiceAddress = "metadata.google.internal:8080"
|
||||||
// defaultTimeout specifies the server handshake timeout.
|
// defaultTimeout specifies the server handshake timeout.
|
||||||
defaultTimeout = 30.0 * time.Second
|
defaultTimeout = 30.0 * time.Second
|
||||||
// The following constants specify the minimum and maximum acceptable
|
// The following constants specify the minimum and maximum acceptable
|
||||||
|
@ -95,39 +98,71 @@ type ClientOptions struct {
|
||||||
// TargetServiceAccounts contains a list of expected target service
|
// TargetServiceAccounts contains a list of expected target service
|
||||||
// accounts.
|
// accounts.
|
||||||
TargetServiceAccounts []string
|
TargetServiceAccounts []string
|
||||||
|
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
|
||||||
|
// address to connect to.
|
||||||
|
HandshakerServiceAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultClientOptions creates a new ClientOptions object with the default
|
||||||
|
// values.
|
||||||
|
func DefaultClientOptions() *ClientOptions {
|
||||||
|
return &ClientOptions{
|
||||||
|
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerOptions contains the server-side options of an ALTS channel. These
|
||||||
|
// options will be passed to the underlying ALTS handshaker.
|
||||||
|
type ServerOptions struct {
|
||||||
|
// HandshakerServiceAddress represents the ALTS handshaker gRPC service
|
||||||
|
// address to connect to.
|
||||||
|
HandshakerServiceAddress string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultServerOptions creates a new ServerOptions object with the default
|
||||||
|
// values.
|
||||||
|
func DefaultServerOptions() *ServerOptions {
|
||||||
|
return &ServerOptions{
|
||||||
|
HandshakerServiceAddress: hypervisorHandshakerServiceAddress,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// altsTC is the credentials required for authenticating a connection using ALTS.
|
// altsTC is the credentials required for authenticating a connection using ALTS.
|
||||||
// It implements credentials.TransportCredentials interface.
|
// It implements credentials.TransportCredentials interface.
|
||||||
type altsTC struct {
|
type altsTC struct {
|
||||||
info *credentials.ProtocolInfo
|
info *credentials.ProtocolInfo
|
||||||
hsAddr string
|
hsAddr string
|
||||||
side core.Side
|
side core.Side
|
||||||
accounts []string
|
accounts []string
|
||||||
|
hsAddress string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
// NewClientCreds constructs a client-side ALTS TransportCredentials object.
|
||||||
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
|
func NewClientCreds(opts *ClientOptions) credentials.TransportCredentials {
|
||||||
return newALTS(core.ClientSide, opts.TargetServiceAccounts)
|
return newALTS(core.ClientSide, opts.TargetServiceAccounts, opts.HandshakerServiceAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
|
// NewServerCreds constructs a server-side ALTS TransportCredentials object.
|
||||||
func NewServerCreds() credentials.TransportCredentials {
|
func NewServerCreds(opts *ServerOptions) credentials.TransportCredentials {
|
||||||
return newALTS(core.ServerSide, nil)
|
return newALTS(core.ServerSide, nil, opts.HandshakerServiceAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newALTS(side core.Side, accounts []string) credentials.TransportCredentials {
|
func newALTS(side core.Side, accounts []string, hsAddress string) credentials.TransportCredentials {
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
vmOnGCP = isRunningOnGCP()
|
vmOnGCP = isRunningOnGCP()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if hsAddress == "" {
|
||||||
|
hsAddress = hypervisorHandshakerServiceAddress
|
||||||
|
}
|
||||||
return &altsTC{
|
return &altsTC{
|
||||||
info: &credentials.ProtocolInfo{
|
info: &credentials.ProtocolInfo{
|
||||||
SecurityProtocol: "alts",
|
SecurityProtocol: "alts",
|
||||||
SecurityVersion: "1.0",
|
SecurityVersion: "1.0",
|
||||||
},
|
},
|
||||||
side: side,
|
side: side,
|
||||||
accounts: accounts,
|
accounts: accounts,
|
||||||
|
hsAddress: hsAddress,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +173,7 @@ func (g *altsTC) ClientHandshake(ctx context.Context, addr string, rawConn net.C
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connecting to ALTS handshaker service.
|
// Connecting to ALTS handshaker service.
|
||||||
hsConn, err := service.Dial()
|
hsConn, err := service.Dial(g.hsAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
@ -191,7 +226,7 @@ func (g *altsTC) ServerHandshake(rawConn net.Conn) (_ net.Conn, _ credentials.Au
|
||||||
return nil, nil, ErrUntrustedPlatform
|
return nil, nil, ErrUntrustedPlatform
|
||||||
}
|
}
|
||||||
// Connecting to ALTS handshaker service.
|
// Connecting to ALTS handshaker service.
|
||||||
hsConn, err := service.Dial()
|
hsConn, err := service.Dial(g.hsAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,7 +28,7 @@ import (
|
||||||
func TestInfoServerName(t *testing.T) {
|
func TestInfoServerName(t *testing.T) {
|
||||||
// This is not testing any handshaker functionality, so it's fine to only
|
// This is not testing any handshaker functionality, so it's fine to only
|
||||||
// use NewServerCreds and not NewClientCreds.
|
// use NewServerCreds and not NewClientCreds.
|
||||||
alts := NewServerCreds()
|
alts := NewServerCreds(DefaultServerOptions())
|
||||||
if got, want := alts.Info().ServerName, ""; got != want {
|
if got, want := alts.Info().ServerName, ""; got != want {
|
||||||
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
|
t.Fatalf("%v.Info().ServerName = %v, want %v", alts, got, want)
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ func TestOverrideServerName(t *testing.T) {
|
||||||
wantServerName := "server.name"
|
wantServerName := "server.name"
|
||||||
// This is not testing any handshaker functionality, so it's fine to only
|
// This is not testing any handshaker functionality, so it's fine to only
|
||||||
// use NewServerCreds and not NewClientCreds.
|
// use NewServerCreds and not NewClientCreds.
|
||||||
c := NewServerCreds()
|
c := NewServerCreds(DefaultServerOptions())
|
||||||
c.OverrideServerName(wantServerName)
|
c.OverrideServerName(wantServerName)
|
||||||
if got, want := c.Info().ServerName, wantServerName; got != want {
|
if got, want := c.Info().ServerName, wantServerName; got != want {
|
||||||
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
|
t.Fatalf("c.Info().ServerName = %v, want %v", got, want)
|
||||||
|
@ -49,7 +49,7 @@ func TestClone(t *testing.T) {
|
||||||
wantServerName := "server.name"
|
wantServerName := "server.name"
|
||||||
// This is not testing any handshaker functionality, so it's fine to only
|
// This is not testing any handshaker functionality, so it's fine to only
|
||||||
// use NewServerCreds and not NewClientCreds.
|
// use NewServerCreds and not NewClientCreds.
|
||||||
c := NewServerCreds()
|
c := NewServerCreds(DefaultServerOptions())
|
||||||
c.OverrideServerName(wantServerName)
|
c.OverrideServerName(wantServerName)
|
||||||
cc := c.Clone()
|
cc := c.Clone()
|
||||||
if got, want := cc.Info().ServerName, wantServerName; got != want {
|
if got, want := cc.Info().ServerName, wantServerName; got != want {
|
||||||
|
@ -67,7 +67,7 @@ func TestClone(t *testing.T) {
|
||||||
func TestInfo(t *testing.T) {
|
func TestInfo(t *testing.T) {
|
||||||
// This is not testing any handshaker functionality, so it's fine to only
|
// This is not testing any handshaker functionality, so it's fine to only
|
||||||
// use NewServerCreds and not NewClientCreds.
|
// use NewServerCreds and not NewClientCreds.
|
||||||
c := NewServerCreds()
|
c := NewServerCreds(DefaultServerOptions())
|
||||||
info := c.Info()
|
info := c.Info()
|
||||||
if got, want := info.ProtocolVersion, ""; got != want {
|
if got, want := info.ProtocolVersion, ""; got != want {
|
||||||
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
|
t.Errorf("info.ProtocolVersion=%v, want %v", got, want)
|
||||||
|
|
|
@ -21,16 +21,12 @@
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// hsServiceAddr specifies the default ALTS handshaker service address in
|
|
||||||
// the hypervisor.
|
|
||||||
hsServiceAddr = flag.String("handshaker_service_address", "metadata.google.internal:8080", "ALTS handshaker gRPC service address")
|
|
||||||
// hsConn represents a connection to hypervisor handshaker service.
|
// hsConn represents a connection to hypervisor handshaker service.
|
||||||
hsConn *grpc.ClientConn
|
hsConn *grpc.ClientConn
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -42,8 +38,8 @@ type dialer func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, erro
|
||||||
|
|
||||||
// Dial dials the handshake service in the hypervisor. If a connection has
|
// Dial dials the handshake service in the hypervisor. If a connection has
|
||||||
// already been established, this function returns it. Otherwise, a new
|
// already been established, this function returns it. Otherwise, a new
|
||||||
// connection is created,
|
// connection is created.
|
||||||
func Dial() (*grpc.ClientConn, error) {
|
func Dial(hsAddress string) (*grpc.ClientConn, error) {
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
@ -51,7 +47,7 @@ func Dial() (*grpc.ClientConn, error) {
|
||||||
// Create a new connection to the handshaker service. Note that
|
// Create a new connection to the handshaker service. Note that
|
||||||
// this connection stays open until the application is closed.
|
// this connection stays open until the application is closed.
|
||||||
var err error
|
var err error
|
||||||
hsConn, err = hsDialer(*hsServiceAddr, grpc.WithInsecure())
|
hsConn, err = hsDialer(hsAddress, grpc.WithInsecure())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,11 @@ import (
|
||||||
grpc "google.golang.org/grpc"
|
grpc "google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The address is irrelevant in this test.
|
||||||
|
testAddress = "some_address"
|
||||||
|
)
|
||||||
|
|
||||||
func TestDial(t *testing.T) {
|
func TestDial(t *testing.T) {
|
||||||
defer func() func() {
|
defer func() func() {
|
||||||
temp := hsDialer
|
temp := hsDialer
|
||||||
|
@ -39,24 +44,24 @@ func TestDial(t *testing.T) {
|
||||||
hsConn = nil
|
hsConn = nil
|
||||||
|
|
||||||
// First call to Dial, it should create set hsConn.
|
// First call to Dial, it should create set hsConn.
|
||||||
conn1, err := Dial()
|
conn1, err := Dial(testAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("first call to Dial failed: %v", err)
|
t.Fatalf("first call to Dial failed: %v", err)
|
||||||
}
|
}
|
||||||
if conn1 == nil {
|
if conn1 == nil {
|
||||||
t.Fatal("first call to Dial()=(nil, _), want not nil")
|
t.Fatal("first call to Dial(_)=(nil, _), want not nil")
|
||||||
}
|
}
|
||||||
if got, want := hsConn, conn1; got != want {
|
if got, want := hsConn, conn1; got != want {
|
||||||
t.Fatalf("hsConn=%v, want %v", got, want)
|
t.Fatalf("hsConn=%v, want %v", got, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Second call to Dial() should return conn1 above.
|
// Second call to Dial should return conn1 above.
|
||||||
conn2, err := Dial()
|
conn2, err := Dial(testAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("second call to Dial() failed: %v", err)
|
t.Fatalf("second call to Dial(_) failed: %v", err)
|
||||||
}
|
}
|
||||||
if got, want := conn2, conn1; got != want {
|
if got, want := conn2, conn1; got != want {
|
||||||
t.Fatalf("second call to Dial()=(%v, _), want (%v,. _)", got, want)
|
t.Fatalf("second call to Dial(_)=(%v, _), want (%v,. _)", got, want)
|
||||||
}
|
}
|
||||||
if got, want := hsConn, conn1; got != want {
|
if got, want := hsConn, conn1; got != want {
|
||||||
t.Fatalf("hsConn=%v, want %v", got, want)
|
t.Fatalf("hsConn=%v, want %v", got, want)
|
||||||
|
|
|
@ -35,13 +35,18 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||||
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
|
opts := alts.DefaultClientOptions()
|
||||||
|
if *hsAddr != "" {
|
||||||
|
opts.HandshakerServiceAddress = *hsAddr
|
||||||
|
}
|
||||||
|
altsTC := alts.NewClientCreds(opts)
|
||||||
// Block until the server is ready.
|
// Block until the server is ready.
|
||||||
conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock())
|
conn, err := grpc.Dial(*serverAddr, grpc.WithTransportCredentials(altsTC), grpc.WithBlock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -31,6 +31,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
hsAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||||
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
serverAddr = flag.String("server_address", ":8080", "The port on which the server is listening")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,7 +42,11 @@ func main() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err)
|
grpclog.Fatalf("gRPC Server: failed to start the server at %v: %v", *serverAddr, err)
|
||||||
}
|
}
|
||||||
altsTC := alts.NewServerCreds()
|
opts := alts.DefaultServerOptions()
|
||||||
|
if *hsAddr != "" {
|
||||||
|
opts.HandshakerServiceAddress = *hsAddr
|
||||||
|
}
|
||||||
|
altsTC := alts.NewServerCreds(opts)
|
||||||
grpcServer := grpc.NewServer(grpc.Creds(altsTC))
|
grpcServer := grpc.NewServer(grpc.Creds(altsTC))
|
||||||
testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer())
|
testpb.RegisterTestServiceServer(grpcServer, interop.NewTestServer())
|
||||||
grpcServer.Serve(lis)
|
grpcServer.Serve(lis)
|
||||||
|
|
|
@ -37,6 +37,7 @@ var (
|
||||||
caFile = flag.String("ca_file", "", "The file containning the CA root cert file")
|
caFile = flag.String("ca_file", "", "The file containning the CA root cert file")
|
||||||
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true")
|
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true")
|
||||||
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
||||||
|
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||||
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
|
testCA = flag.Bool("use_test_ca", false, "Whether to replace platform root CAs with test CA as the CA root")
|
||||||
serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file")
|
serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file")
|
||||||
oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens")
|
oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens")
|
||||||
|
@ -110,7 +111,11 @@ func main() {
|
||||||
opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope))))
|
opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(interop.GetToken(*serviceAccountKeyFile, *oauthScope))))
|
||||||
}
|
}
|
||||||
} else if *useALTS {
|
} else if *useALTS {
|
||||||
altsTC := alts.NewClientCreds(&alts.ClientOptions{})
|
altsOpts := alts.DefaultClientOptions()
|
||||||
|
if *altsHSAddr != "" {
|
||||||
|
altsOpts.HandshakerServiceAddress = *altsHSAddr
|
||||||
|
}
|
||||||
|
altsTC := alts.NewClientCreds(altsOpts)
|
||||||
opts = append(opts, grpc.WithTransportCredentials(altsTC))
|
opts = append(opts, grpc.WithTransportCredentials(altsTC))
|
||||||
} else {
|
} else {
|
||||||
opts = append(opts, grpc.WithInsecure())
|
opts = append(opts, grpc.WithInsecure())
|
||||||
|
|
|
@ -33,11 +33,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
|
useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP")
|
||||||
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
useALTS = flag.Bool("use_alts", false, "Connection uses ALTS if true (this option can only be used on GCP)")
|
||||||
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
|
altsHSAddr = flag.String("alts_handshaker_service_address", "", "ALTS handshaker gRPC service address")
|
||||||
keyFile = flag.String("tls_key_file", "", "The TLS key file")
|
certFile = flag.String("tls_cert_file", "", "The TLS cert file")
|
||||||
port = flag.Int("port", 10000, "The server port")
|
keyFile = flag.String("tls_key_file", "", "The TLS key file")
|
||||||
|
port = flag.Int("port", 10000, "The server port")
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@ -64,7 +65,11 @@ func main() {
|
||||||
}
|
}
|
||||||
opts = append(opts, grpc.Creds(creds))
|
opts = append(opts, grpc.Creds(creds))
|
||||||
} else if *useALTS {
|
} else if *useALTS {
|
||||||
altsTC := alts.NewServerCreds()
|
altsOpts := alts.DefaultServerOptions()
|
||||||
|
if *altsHSAddr != "" {
|
||||||
|
altsOpts.HandshakerServiceAddress = *altsHSAddr
|
||||||
|
}
|
||||||
|
altsTC := alts.NewServerCreds(altsOpts)
|
||||||
opts = append(opts, grpc.Creds(altsTC))
|
opts = append(opts, grpc.Creds(altsTC))
|
||||||
}
|
}
|
||||||
server := grpc.NewServer(opts...)
|
server := grpc.NewServer(opts...)
|
||||||
|
|
Loading…
Reference in New Issue