diff --git a/clientconn.go b/clientconn.go index c4a343b73..fe0a3a044 100644 --- a/clientconn.go +++ b/clientconn.go @@ -60,11 +60,9 @@ func WithClientTLS(creds credentials.TransportAuthenticator) DialOption { } } -// WithComputeEngine returns a DialOption which sets -// credentials which use application default credentials as provided to -// Google Compute Engine. Note that TLS credentials is typically also -// needed. If it is the case, users need to pass WithTLS option too. -func WithComputeEngine(creds credentials.Credentials) DialOption { +// WithPerRPCCredentials returns a DialOption which sets +// credentials which will place auth state on each outbound RPC. +func WithPerRPCCredentials(creds credentials.Credentials) DialOption { return func(o *dialOptions) { o.authOptions = append(o.authOptions, creds) } diff --git a/credentials/credentials.go b/credentials/credentials.go index 5990a49f8..b6c47d125 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -31,25 +31,23 @@ * */ -// Package credentials implements various credentials supported by gRPC library. +// Package credentials implements various credentials supported by gRPC library, +// which encapsulate all the state needed by a client to authenticate with a +// server and make various assertions, e.g., about the client's identity, role, +// or whether it is authorized to make a particular call. package credentials // import "google.golang.org/grpc/credentials" import ( "crypto/tls" "crypto/x509" - "encoding/json" "fmt" "io/ioutil" "net" - "net/http" - "net/url" - "sync" - "time" -) -const ( - metadataServer = "metadata" - serviceAccountPath = "/computeMetadata/v1/instance/service-accounts/default/token" + "golang.org/x/net/context" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/jwt" ) var ( @@ -63,8 +61,11 @@ type Credentials interface { // GetRequestMetadata gets the current request metadata, refreshing // tokens if required. This should be called by the transport layer on // each request, and the data should be populated in headers or other - // context. The operation may do things like refresh tokens. - GetRequestMetadata() (map[string]string, error) + // context. When supported by the underlying implementation, ctx can + // be used for timeout and cancellation. + // TODO(zhaoq): Define the set of the qualified keys instead of leaving + // it as an arbitrary string. + GetRequestMetadata(ctx context.Context) (map[string]string, error) } // TransportAuthenticator defines the common interface all supported transport @@ -98,7 +99,9 @@ type tlsCreds struct { // GetRequestMetadata returns nil, nil since TLS credentials does not have // metadata. -func (c *tlsCreds) GetRequestMetadata() (map[string]string, error) { +// TODO(zhaoq): Define the set of the qualified keys instead of leaving it as an +// arbitrary string. +func (c *tlsCreds) GetRequestMetadata(ctx context.Context) (map[string]string, error) { return nil, nil } @@ -108,7 +111,7 @@ func (c *tlsCreds) Dial(addr string) (_ net.Conn, err error) { if name == "" { name, _, err = net.SplitHostPort(addr) if err != nil { - return nil, fmt.Errorf("failed to parse server address %v", err) + return nil, fmt.Errorf("credentials: failed to parse server address %v", err) } } return tls.Dial("tcp", addr, &tls.Config{ @@ -143,7 +146,7 @@ func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, } cp := x509.NewCertPool() if !cp.AppendCertsFromPEM(b) { - return nil, fmt.Errorf("failed to append certificates") + return nil, fmt.Errorf("credentials: failed to append certificates") } return &tlsCreds{ serverName: serverName, @@ -170,86 +173,68 @@ func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, err }, nil } -type tokenData struct { - accessToken string - expiresIn float64 - tokeType string -} - -type token struct { - accessToken string - expiry time.Time -} - -// expired returns true if there is no access token or the -// access token is expired. -func (t token) expired() bool { - if t.accessToken == "" { - return true - } - if t.expiry.IsZero() { - return false - } - return t.expiry.Before(time.Now()) -} - -// computeEngine uses the Application Default Credentials as provided to Google Compute Engine instances. +// computeEngine represents credentials for the built-in service account for +// the currently running Google Compute Engine (GCE) instance. It uses the +// metadata server to get access tokens. type computeEngine struct { - mu sync.Mutex - t token + ts oauth2.TokenSource } -// GetRequestMetadata returns a refreshed access token. -func (c *computeEngine) GetRequestMetadata() (map[string]string, error) { - c.mu.Lock() - defer c.mu.Unlock() - if c.t.expired() { - if err := c.refresh(); err != nil { - return nil, err - } +func (c computeEngine) GetRequestMetadata(ctx context.Context) (map[string]string, error) { + token, err := c.ts.Token() + if err != nil { + return nil, err } return map[string]string{ - "authorization": "Bearer " + c.t.accessToken, + "authorization": token.TokenType + " " + token.AccessToken, }, nil } -func (c *computeEngine) refresh() error { - // https://developers.google.com/compute/docs/metadata - // v1 requires "Metadata-Flavor: Google" header. - tokenURL := &url.URL{ - Scheme: "http", - Host: metadataServer, - Path: serviceAccountPath, +// NewComputeEngine constructs the credentials that fetches access tokens from +// Google Compute Engine (GCE)'s metadata server. It is only valid to use this +// if your program is running on a GCE instance. +func NewComputeEngine() Credentials { + return computeEngine{ + ts: google.ComputeTokenSource(""), } - req, err := http.NewRequest("GET", tokenURL.String(), nil) - if err != nil { - return err - } - req.Header.Add("Metadata-Flavor", "Google") - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - var td tokenData - err = json.NewDecoder(resp.Body).Decode(&td) - if err != nil { - return err - } - // No need to check td.tokenType. - c.t = token{ - accessToken: td.accessToken, - expiry: time.Now().Add(time.Duration(td.expiresIn) * time.Second), - } - return nil } -// NewComputeEngine constructs a credentials for GCE. -func NewComputeEngine() (Credentials, error) { - creds := &computeEngine{} - // TODO(zhaoq): This is not optimal if refresh() is persistently failed. - if err := creds.refresh(); err != nil { +// serviceAccount represents credentials via JWT signing key. +type serviceAccount struct { + config *jwt.Config +} + +func (s serviceAccount) GetRequestMetadata(ctx context.Context) (map[string]string, error) { + c, ok := ctx.(oauth2.Context) + if !ok { + return nil, fmt.Errorf("credentials: the context %v is invalid", ctx) + } + token, err := s.config.TokenSource(c).Token() + if err != nil { return nil, err } - return creds, nil + return map[string]string{ + "authorization": token.TokenType + " " + token.AccessToken, + }, nil } + +// NewServiceAccountFromKey constructs the credentials using the JSON key slice +// from a Google Developers service account. +func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (Credentials, error) { + config, err := google.JWTConfigFromJSON(jsonKey, scope...) + if err != nil { + return nil, err + } + return serviceAccount{config: config}, nil +} + +// NewServiceAccountFromFile constructs the credentials using the JSON key file +// of a Google Developers service account. +func NewServiceAccountFromFile(keyFile string, scope ...string) (Credentials, error) { + jsonKey, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) + } + return NewServiceAccountFromKey(jsonKey, scope...) +} + diff --git a/transport/http2_client_transport.go b/transport/http2_client_transport.go index e3e482fdb..9fd6283b0 100644 --- a/transport/http2_client_transport.go +++ b/transport/http2_client_transport.go @@ -44,10 +44,10 @@ import ( "github.com/bradfitz/http2" "github.com/bradfitz/http2/hpack" + "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/metadata" - "golang.org/x/net/context" ) // http2Client implements the ClientTransport interface with HTTP2. @@ -218,7 +218,12 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) for _, c := range t.authCreds { - m, err := c.GetRequestMetadata() + m, err := c.GetRequestMetadata(ctx) + select { + case <-ctx.Done(): + return nil, ContextErr(ctx.Err()) + default: + } if err != nil { return nil, StreamErrorf(codes.InvalidArgument, "%v", err) }