diff --git a/Documentation/grpc-auth-support.md b/Documentation/grpc-auth-support.md index c8c2a556c..3c47137cc 100644 --- a/Documentation/grpc-auth-support.md +++ b/Documentation/grpc-auth-support.md @@ -58,6 +58,11 @@ Note, the OAuth2 implementation of `grpc.PerRPCCredentials` requires a client to [grpc.WithTransportCredentials](https://godoc.org/google.golang.org/grpc#WithTransportCredentials) to prevent any insecure transmission of tokens. +The default behaviour is to strip the gRPC method from the endpoint that is passed to the +`GetRequestMetadata` method of `PerRPCCredentials`. However, this can be overridden to pass +the entire endpoint as required for some JWT implementations by setting the +`GRPC_AUDIENCE_IS_FULL_PATH` environment variable to `"true"`. + # Authenticating with Google ## Google Compute Engine (GCE) diff --git a/internal/envconfig/envconfig.go b/internal/envconfig/envconfig.go index 65fde7dc2..fd4e0d376 100644 --- a/internal/envconfig/envconfig.go +++ b/internal/envconfig/envconfig.go @@ -80,6 +80,13 @@ var ( // ALTSHandshakerKeepaliveParams is set if we should add the // KeepaliveParams when dial the ALTS handshaker service. ALTSHandshakerKeepaliveParams = boolFromEnv("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false) + + // AudienceIsFullPath is set if the user expects that the endpoint that + // is passed to the credential helper called by GetRequestMetadata contains + // the full URL rather than excluding the method. This is required as there + // are competing specifications around what endpoint should be specified for + // a JWT audience. + AudienceIsFullPath = boolFromEnv("GRPC_AUDIENCE_IS_FULL_PATH", false) ) func boolFromEnv(envVar string, def bool) bool { diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 5467fe971..2835d242b 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -39,6 +39,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" icredentials "google.golang.org/grpc/internal/credentials" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" @@ -645,6 +646,10 @@ func (t *http2Client) createAudience(callHdr *CallHdr) string { // Construct URI required to get auth request metadata. // Omit port if it is the default one. host := strings.TrimSuffix(callHdr.Host, ":443") + if envconfig.AudienceIsFullPath { + return "https://" + host + callHdr.Method + } + pos := strings.LastIndex(callHdr.Method, "/") if pos == -1 { pos = len(callHdr.Method) diff --git a/test/creds_test.go b/test/creds_test.go index bedafa5b7..cc4bcb886 100644 --- a/test/creds_test.go +++ b/test/creds_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" @@ -543,3 +544,58 @@ func (s) TestServerCredsDispatch(t *testing.T) { t.Errorf("Read() = %v, %v; want n>0, ", n, err) } } + +type audienceTestCreds struct{} + +func (a *audienceTestCreds) GetRequestMetadata(_ context.Context, uri ...string) (map[string]string, error) { + var endpoint string + if len(uri) > 0 { + endpoint = uri[0] + } + return nil, status.Error(codes.Unknown, endpoint) +} + +func (a *audienceTestCreds) RequireTransportSecurity() bool { return false } + +func (s) TestGRPCMethodInAudienceWhenEnvironmentSet(t *testing.T) { + te := newTest(t, env{name: "method-in-audience", network: "tcp"}) + te.userAgent = testAppUA + te.startServer(&testServer{security: te.e.security}) + defer te.tearDown() + + cc := te.clientConn(grpc.WithPerRPCCredentials(&audienceTestCreds{})) + tc := testgrpc.NewTestServiceClient(cc) + + tests := []struct { + name string + endpoint string + audienceIsFullPath bool + }{ + { + name: "full-path-sent", + endpoint: fmt.Sprintf("https://%s/grpc.testing.TestService/EmptyCall", te.srvAddr), + audienceIsFullPath: true, + }, + { + name: "method-omitted", + endpoint: fmt.Sprintf("https://%s/grpc.testing.TestService", te.srvAddr), + audienceIsFullPath: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutils.SetEnvConfig(t, &envconfig.AudienceIsFullPath, test.audienceIsFullPath) + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); status.Convert(err).Message() != test.endpoint { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, test.endpoint) + } + + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); status.Convert(err).Message() != test.endpoint { + t.Fatalf("ss.client.EmptyCall(_, _) = _, %v; want _, _.Message()=%q", err, test.endpoint) + } + }) + } +}