This commit is contained in:
Dimitar Pavlov 2025-08-21 17:39:43 +00:00 committed by GitHub
commit c9deb57fa4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1078 additions and 0 deletions

50
credentials/jwt/doc.go Normal file
View File

@ -0,0 +1,50 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package jwt implements JWT token file-based call credentials.
//
// This package provides support for A97 JWT Call Credentials, allowing gRPC
// clients to authenticate using JWT tokens read from files. While originally
// designed for xDS environments, these credentials are general-purpose.
//
// The credentials can be used directly in gRPC clients or configured via xDS.
//
// # Token Requirements
//
// JWT tokens must:
// - Be valid, well-formed JWT tokens with header, payload, and signature
// - Include an "exp" (expiration) claim
// - Be readable from the specified file path
//
// # Considerations
//
// - Tokens are cached until expiration to avoid excessive file I/O
// - Transport security is required (RequireTransportSecurity returns true)
// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or
// UNAUTHENTICATED errors. The errors are cached and retried with exponential
// backoff.
//
// This implementation is originally intended for use in service mesh
// environments like Istio where JWT tokens are provisioned and rotated by the
// infrastructure.
//
// # Experimental
//
// Notice: All APIs in this package are experimental and may be removed in a
// later release.
package jwt

View File

@ -0,0 +1,103 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package jwt
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
"strings"
"time"
)
// jwtClaims represents the JWT claims structure for extracting expiration time.
type jwtClaims struct {
Exp int64 `json:"exp"`
}
// jWTFileReader handles reading and parsing JWT tokens from files.
type jWTFileReader struct {
tokenFilePath string
}
// newJWTFileReader creates a new JWTFileReader for the specified file path.
func newJWTFileReader(tokenFilePath string) *jWTFileReader {
return &jWTFileReader{
tokenFilePath: tokenFilePath,
}
}
// ReadToken reads and parses a JWT token from the configured file.
// Returns the token string, expiration time, and any error encountered.
func (r *jWTFileReader) ReadToken() (string, time.Time, error) {
tokenBytes, err := os.ReadFile(r.tokenFilePath)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to read token file %q: %v", r.tokenFilePath, err)
}
token := strings.TrimSpace(string(tokenBytes))
if token == "" {
return "", time.Time{}, fmt.Errorf("token file %q is empty", r.tokenFilePath)
}
exp, err := r.extractExpiration(token)
if err != nil {
return "", time.Time{}, fmt.Errorf("failed to parse JWT from token file %q: %v", r.tokenFilePath, err)
}
return token, exp, nil
}
// extractExpiration parses the JWT token to extract the expiration time.
func (r *jWTFileReader) extractExpiration(token string) (time.Time, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
payload := parts[1]
// Add padding if necessary for base64 decoding.
if m := len(payload) % 4; m != 0 {
payload += strings.Repeat("=", 4-m)
}
payloadBytes, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err)
}
var claims jwtClaims
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
}
if claims.Exp == 0 {
return time.Time{}, fmt.Errorf("JWT token has no expiration claim")
}
expTime := time.Unix(claims.Exp, 0)
// Check if token is already expired.
if expTime.Before(time.Now()) {
return time.Time{}, fmt.Errorf("JWT token is expired")
}
return expTime, nil
}

View File

@ -0,0 +1,180 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package jwt
import (
"encoding/base64"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) {
tests := []struct {
name string
setupFile func(string) error
wantErrContains string
}{
{
name: "nonexistent file",
setupFile: func(_ string) error {
return nil // Don't create the file
},
wantErrContains: "failed to read token file",
},
{
name: "empty file",
setupFile: func(path string) error {
return os.WriteFile(path, []byte(""), 0600)
},
wantErrContains: "token file",
},
{
name: "file with whitespace only",
setupFile: func(path string) error {
return os.WriteFile(path, []byte(" \n\t "), 0600)
},
wantErrContains: "token file",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
tokenFile := filepath.Join(tempDir, "token")
if err := tt.setupFile(tokenFile); err != nil {
t.Fatalf("Failed to setup test file: %v", err)
}
reader := newJWTFileReader(tokenFile)
_, _, err := reader.ReadToken()
if err == nil {
t.Fatal("ReadToken() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErrContains) {
t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains)
}
})
}
}
func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) {
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
tokenContent string
wantErrContains string
}{
{
name: "valid token without expiration",
tokenContent: createTestJWT(t, "", time.Time{}),
wantErrContains: "JWT token has no expiration claim",
},
{
name: "expired token",
tokenContent: createTestJWT(t, "", now.Add(-time.Hour)),
wantErrContains: "JWT token is expired",
},
{
name: "malformed JWT - not enough parts",
tokenContent: "invalid.jwt",
wantErrContains: "invalid JWT format: expected 3 parts, got 2",
},
{
name: "malformed JWT - invalid base64",
tokenContent: "header.invalid_base64!@#.signature",
wantErrContains: "failed to decode JWT payload",
},
{
name: "malformed JWT - invalid JSON",
tokenContent: createInvalidJSONJWT(t),
wantErrContains: "failed to unmarshal JWT claims",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenFile := writeTempFile(t, "token", tt.tokenContent)
reader := newJWTFileReader(tokenFile)
_, _, err := reader.ReadToken()
if err == nil {
t.Fatal("ReadToken() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErrContains) {
t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains)
}
})
}
}
func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) {
now := time.Now().Truncate(time.Second)
tokenExp := now.Add(time.Hour)
token := createTestJWT(t, "https://example.com", tokenExp)
tokenFile := writeTempFile(t, "token", token)
reader := newJWTFileReader(tokenFile)
readToken, expiry, err := reader.ReadToken()
if err != nil {
t.Fatalf("ReadToken() unexpected error: %v", err)
}
if readToken != token {
t.Errorf("ReadToken() token = %q, want %q", readToken, token)
}
if !expiry.Equal(tokenExp) {
t.Errorf("ReadToken() expiry = %v, want %v", expiry, tokenExp)
}
}
// createInvalidJSONJWT creates a JWT with invalid JSON in the payload.
func createInvalidJSONJWT(t *testing.T) string {
t.Helper()
header := map[string]any{
"typ": "JWT",
"alg": "HS256",
}
headerBytes, err := json.Marshal(header)
if err != nil {
t.Fatalf("Failed to marshal header: %v", err)
}
headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
headerB64 = strings.TrimRight(headerB64, "=")
// Create invalid JSON payload
invalidJSON := "invalid json content"
payloadB64 := base64.URLEncoding.EncodeToString([]byte(invalidJSON))
payloadB64 = strings.TrimRight(payloadB64, "=")
signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
signature = strings.TrimRight(signature, "=")
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature)
}

View File

@ -0,0 +1,181 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package jwt implements gRPC credentials using JWT tokens from files.
package jwt
import (
"context"
"fmt"
"strings"
"sync"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/backoff"
"google.golang.org/grpc/status"
)
const preemptiveRefreshThreshold = time.Minute
// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads
// tokens from a file.
// This implementation follows the A97 JWT Call Credentials specification.
type jwtTokenFileCallCreds struct {
fileReader *jWTFileReader
backoffStrategy backoff.Strategy
// cached data protected by mu
mu sync.Mutex
cachedAuthHeader string // "Bearer " + token
cachedExpiry time.Time // Slightly less than actual expiration time
cachedError error // Error from last failed attempt
retryAttempt int // Current retry attempt number
nextRetryTime time.Time // When next retry is allowed
pendingRefresh bool // Whether a refresh is currently in progress
}
// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens
// from the specified file path.
func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) {
if tokenFilePath == "" {
return nil, fmt.Errorf("tokenFilePath cannot be empty")
}
creds := &jwtTokenFileCallCreds{
fileReader: newJWTFileReader(tokenFilePath),
backoffStrategy: backoff.DefaultExponential,
}
return creds, nil
}
// GetRequestMetadata gets the current request metadata, refreshing tokens if
// required. This implementation follows the PerRPCCredentials interface. The
// tokens will get automatically refreshed if they are about to expire or if
// they haven't been loaded successfully yet.
// If it's not possible to extract a token from the file, UNAVAILABLE is
// returned.
// If the token is extracted but invalid, then UNAUTHENTICATED is returned.
// If errors are encoutered, a backoff is applied before retrying.
func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
ri, _ := credentials.RequestInfoFromContext(ctx)
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err)
}
c.mu.Lock()
defer c.mu.Unlock()
if c.isTokenValidLocked() {
if c.needsPreemptiveRefreshLocked() {
// Start refresh if not pending (handling the prior RPC may have
// just spawned a goroutine).
if !c.pendingRefresh {
c.pendingRefresh = true
go c.refreshToken()
}
}
return map[string]string{
"authorization": c.cachedAuthHeader,
}, nil
}
// If in backoff state, just return the cached error.
if c.cachedError != nil && time.Now().Before(c.nextRetryTime) {
return nil, c.cachedError
}
// At this point, the token is either invalid or expired and we are no
// longer backing off. So refresh it.
token, expiry, err := c.fileReader.ReadToken()
c.updateCacheLocked(token, expiry, err)
if c.cachedError != nil {
return nil, c.cachedError
}
return map[string]string{
"authorization": c.cachedAuthHeader,
}, nil
}
// RequireTransportSecurity indicates whether the credentials requires
// transport security.
func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool {
return true
}
// isTokenValidLocked checks if the cached token is still valid.
// Caller must hold c.mu lock.
func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool {
if c.cachedAuthHeader == "" {
return false
}
return c.cachedExpiry.After(time.Now())
}
// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be
// triggered.
// Returns true if the cached token is valid but expires within 1 minute.
// We only trigger pre-emptive refresh for valid tokens - if the token is
// invalid or expired, the next RPC will handle synchronous refresh instead.
// Caller must hold c.mu lock.
func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool {
return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < preemptiveRefreshThreshold
}
// refreshToken reads the token from file and updates the cached data.
func (c *jwtTokenFileCallCreds) refreshToken() {
// Deliberately not locking c.mu here
token, expiry, err := c.fileReader.ReadToken()
c.mu.Lock()
defer c.mu.Unlock()
c.updateCacheLocked(token, expiry, err)
c.pendingRefresh = false
}
// updateCacheLocked updates the cached token, expiry, and error state.
// If an error is provided, it determines whether to set it as an UNAVAILABLE
// or UNAUTHENTICATED error based on the error type.
// Caller must hold c.mu lock.
func (c *jwtTokenFileCallCreds) updateCacheLocked(token string, expiry time.Time, err error) {
if err != nil {
// Convert to gRPC status codes
if strings.Contains(err.Error(), "failed to read token file") || strings.Contains(err.Error(), "token file") && strings.Contains(err.Error(), "is empty") {
c.cachedError = status.Errorf(codes.Unavailable, "%v", err)
} else {
c.cachedError = status.Errorf(codes.Unauthenticated, "%v", err)
}
c.retryAttempt++
backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1)
c.nextRetryTime = time.Now().Add(backoffDelay)
} else {
// Success - clear any cached error and update token cache
c.cachedError = nil
c.retryAttempt = 0
c.nextRetryTime = time.Time{}
c.cachedAuthHeader = "Bearer " + token
// Per RFC A97: consider token invalid if it expires within the next 30
// seconds to accommodate for clock skew and server processing time.
c.cachedExpiry = expiry.Add(-30 * time.Second)
}
}

View File

@ -0,0 +1,564 @@
/*
*
* Copyright 2025 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package jwt
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/status"
)
const defaultTestTimeout = 5 * time.Second
type s struct {
grpctest.Tester
}
func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
func (s) TestNewTokenFileCallCredentials(t *testing.T) {
tests := []struct {
name string
tokenFilePath string
wantErr string
}{
{
name: "some filepath",
tokenFilePath: "/path/to/token",
wantErr: "",
},
{
name: "empty filepath",
tokenFilePath: "",
wantErr: "tokenFilePath cannot be empty",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
creds, err := NewTokenFileCallCredentials(tt.tokenFilePath)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("NewTokenFileCallCredentials() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err)
}
if creds == nil {
t.Fatal("NewTokenFileCallCredentials() returned nil credentials")
}
})
}
}
func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) {
creds, err := NewTokenFileCallCredentials("/path/to/token")
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
if !creds.RequireTransportSecurity() {
t.Error("RequireTransportSecurity() = false, want true")
}
}
func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) {
tempDir, err := os.MkdirTemp("", "jwt_test")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tempDir)
now := time.Now().Truncate(time.Second)
tests := []struct {
name string
tokenContent string
authInfo credentials.AuthInfo
wantErr bool
wantErrContains string
wantMetadata map[string]string
}{
{
name: "valid token with future expiration",
tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)),
authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
wantErr: false,
wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))},
},
{
name: "insufficient security level",
tokenContent: createTestJWT(t, "", now.Add(time.Hour)),
authInfo: &testAuthInfo{secLevel: credentials.NoSecurity},
wantErr: true,
wantErrContains: "unable to transfer JWT token file PerRPCCredentials",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenFile := writeTempFile(t, "token", tt.tokenContent)
creds, err := NewTokenFileCallCredentials(tokenFile)
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
AuthInfo: tt.authInfo,
})
metadata, err := creds.GetRequestMetadata(ctx)
if tt.wantErr {
if err == nil {
t.Fatalf("GetRequestMetadata() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErrContains) {
t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Fatalf("GetRequestMetadata() unexpected error: %v", err)
}
if len(metadata) != len(tt.wantMetadata) {
t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata))
}
for k, v := range tt.wantMetadata {
if metadata[k] != v {
t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v)
}
}
})
}
}
func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) {
token := createTestJWT(t, "", time.Now().Add(time.Hour))
tokenFile := writeTempFile(t, "token", token)
creds, err := NewTokenFileCallCredentials(tokenFile)
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
})
// First call should read from file.
metadata1, err := creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("First GetRequestMetadata() failed: %v", err)
}
// Update the file with a different token.
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
t.Fatalf("Failed to update token file: %v", err)
}
// Second call should return cached token (not the updated one).
metadata2, err := creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
}
if metadata1["authorization"] != metadata2["authorization"] {
t.Error("Expected cached token to be returned, but got different token")
}
}
// testAuthInfo implements credentials.AuthInfo for testing.
type testAuthInfo struct {
secLevel credentials.SecurityLevel
}
func (t *testAuthInfo) AuthType() string {
return "test"
}
func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo {
return credentials.CommonAuthInfo{SecurityLevel: t.secLevel}
}
// Tests that cached token expiration is set to 30 seconds before actual token
// expiration.
func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) {
// Create token that expires in 2 hours.
tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour)
token := createTestJWT(t, "", tokenExp)
tokenFile := writeTempFile(t, "token", token)
creds, err := NewTokenFileCallCredentials(tokenFile)
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
})
// Get token to trigger caching.
_, err = creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("GetRequestMetadata() failed: %v", err)
}
// Verify cached expiration is 30 seconds before actual token expiration.
impl := creds.(*jwtTokenFileCallCreds)
impl.mu.Lock()
cachedExp := impl.cachedExpiry
impl.mu.Unlock()
expectedExp := tokenExp.Add(-30 * time.Second)
if !cachedExp.Equal(expectedExp) {
t.Errorf("cache expiration = %v, want %v", cachedExp, expectedExp)
}
}
// Tests that pre-emptive refresh is triggered within 1 minute of expiration.
func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) {
// Create token that expires in 80 seconds (=> cache expires in ~50s).
// This ensures pre-emptive refresh triggers since 50s < the 1 minute check.
tokenExp := time.Now().Add(80 * time.Second)
expiringToken := createTestJWT(t, "", tokenExp)
tokenFile := writeTempFile(t, "token", expiringToken)
creds, err := NewTokenFileCallCredentials(tokenFile)
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
})
// Get token - should trigger pre-emptive refresh.
metadata1, err := creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("GetRequestMetadata() failed: %v", err)
}
// Verify token was cached and check if refresh should be triggered.
impl := creds.(*jwtTokenFileCallCreds)
impl.mu.Lock()
cacheExp := impl.cachedExpiry
tokenCached := impl.cachedAuthHeader != ""
shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked()
impl.mu.Unlock()
if !tokenCached {
t.Error("token should be cached after successful GetRequestMetadata")
}
if !shouldTriggerRefresh {
timeUntilExp := time.Until(cacheExp)
t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp)
}
// Create new token file with different expiration while refresh is
// happening.
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
t.Fatalf("Failed to write updated token file: %v", err)
}
// Get token again - should trigger a refresh given that the first one was
// cached but expiring soon.
// However, the function should have returned right away with the current
// cached token.
metadata2, err := creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Now should get the new token.
metadata3, err := creds.GetRequestMetadata(ctx)
if err != nil {
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
}
// If pre-emptive refresh worked, we should get the new token.
expectedAuth1 := "Bearer " + expiringToken
expectedAuth2 := "Bearer " + expiringToken
expectedAuth3 := "Bearer " + newToken
actualAuth1 := metadata1["authorization"]
actualAuth2 := metadata2["authorization"]
actualAuth3 := metadata3["authorization"]
if actualAuth1 != expectedAuth1 {
t.Errorf("First call should return original token: got %q, want %q", actualAuth1, expectedAuth1)
}
if actualAuth2 != expectedAuth2 {
t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2)
}
if actualAuth3 != expectedAuth3 {
t.Errorf("Third call should return the new token: got %q, want %q", actualAuth3, expectedAuth3)
}
}
// Tests that backoff behavior handles file read errors correctly.
func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) {
// This test has the following expectations:
// First call to GetRequestMetadata() fails with UNAVAILABLE due to a
// missing file.
// Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff.
// Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry.
// Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff
// even though file exists.
// Fifth call to GetRequestMetadata() succeeds after reading the file and
// backoff has expired.
tempDir := t.TempDir()
nonExistentFile := filepath.Join(tempDir, "nonexistent")
creds, err := NewTokenFileCallCredentials(nonExistentFile)
if err != nil {
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
})
// First call should fail with UNAVAILABLE.
_, err1 := creds.GetRequestMetadata(ctx)
if err1 == nil {
t.Fatal("Expected error from nonexistent file")
}
if status.Code(err1) != codes.Unavailable {
t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1))
}
// Verify error is cached internally.
impl := creds.(*jwtTokenFileCallCreds)
impl.mu.Lock()
cachedErr := impl.cachedError
retryAttempt := impl.retryAttempt
nextRetryTime := impl.nextRetryTime
impl.mu.Unlock()
if cachedErr == nil {
t.Error("error should be cached internally after failed file read")
}
if retryAttempt != 1 {
t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt)
}
if nextRetryTime.IsZero() || nextRetryTime.Before(time.Now()) {
t.Error("Next retry time should be set to future time")
}
// Second call should still return cached error.
_, err2 := creds.GetRequestMetadata(ctx)
if err2 == nil {
t.Fatal("Expected cached error")
}
if status.Code(err2) != codes.Unavailable {
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err2))
}
if err1.Error() != err2.Error() {
t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error())
}
impl.mu.Lock()
retryAttempt2 := impl.retryAttempt
nextRetryTime2 := impl.nextRetryTime
impl.mu.Unlock()
if !nextRetryTime2.Equal(nextRetryTime) {
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime)
}
if retryAttempt2 != 1 {
t.Error("retry attempt should not change due to backoff")
}
// Fast-forward the backoff retry time to allow next retry attempt.
impl.mu.Lock()
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
impl.mu.Unlock()
// Third call should retry but still fail with UNAVAILABLE.
_, err3 := creds.GetRequestMetadata(ctx)
if err3 == nil {
t.Fatal("Expected cached error")
}
if status.Code(err3) != codes.Unavailable {
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err3))
}
if err3.Error() != err1.Error() {
t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error())
}
impl.mu.Lock()
retryAttempt3 := impl.retryAttempt
nextRetryTime3 := impl.nextRetryTime
impl.mu.Unlock()
if !nextRetryTime3.After(nextRetryTime2) {
t.Error("nextRetryTime should not change due to backoff")
}
if retryAttempt3 != 2 {
t.Error("retry attempt should not change due to backoff")
}
// Create valid token file.
validToken := createTestJWT(t, "", time.Now().Add(time.Hour))
if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil {
t.Fatalf("Failed to create valid token file: %v", err)
}
// Fourth call should still fail even though the file now exists.
_, err4 := creds.GetRequestMetadata(ctx)
if err4 == nil {
t.Fatal("Expected cached error")
}
if status.Code(err4) != codes.Unavailable {
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err4))
}
if err4.Error() != err3.Error() {
t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error())
}
impl.mu.Lock()
retryAttempt4 := impl.retryAttempt
nextRetryTime4 := impl.nextRetryTime
impl.mu.Unlock()
if !nextRetryTime4.Equal(nextRetryTime3) {
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3)
}
if retryAttempt4 != retryAttempt3 {
t.Error("retry attempt should not change due to backoff")
}
// Fast-forward the backoff retry time to allow next retry attempt.
impl.mu.Lock()
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
impl.mu.Unlock()
// Fifth call should succeed since the file now exists
// and the backoff has expired.
_, err5 := creds.GetRequestMetadata(ctx)
if err5 != nil {
t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5)
t.Error("backoff should expire and trigger new attempt on next RPC")
} else {
// If successful, verify error cache and backoff state were cleared.
impl.mu.Lock()
clearedErr := impl.cachedError
retryAttempt := impl.retryAttempt
nextRetryTime := impl.nextRetryTime
impl.mu.Unlock()
if clearedErr != nil {
t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr)
}
if retryAttempt != 0 {
t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt)
}
if !nextRetryTime.IsZero() {
t.Error("after successful retry, next retry time should be cleared")
}
}
}
// createTestJWT creates a test JWT token with the specified audience and
// expiration.
func createTestJWT(t *testing.T, audience string, expiration time.Time) string {
t.Helper()
header := map[string]any{
"typ": "JWT",
"alg": "HS256",
}
claims := map[string]any{}
if audience != "" {
claims["aud"] = audience
}
if !expiration.IsZero() {
claims["exp"] = expiration.Unix()
}
headerBytes, err := json.Marshal(header)
if err != nil {
t.Fatalf("Failed to marshal header: %v", err)
}
claimsBytes, err := json.Marshal(claims)
if err != nil {
t.Fatalf("Failed to marshal claims: %v", err)
}
headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes)
// Remove padding for URL-safe base64
headerB64 = strings.TrimRight(headerB64, "=")
claimsB64 = strings.TrimRight(claimsB64, "=")
// For testing, we'll use a fake signature
signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
signature = strings.TrimRight(signature, "=")
return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature)
}
func writeTempFile(t *testing.T, name, content string) string {
t.Helper()
tempDir := t.TempDir()
filePath := filepath.Join(tempDir, name)
if err := os.WriteFile(filePath, []byte(content), 0600); err != nil {
t.Fatalf("Failed to write temp file: %v", err)
}
return filePath
}