mirror of https://github.com/grpc/grpc-go.git
Merge 12fedd5964
into 01ae4f4c48
This commit is contained in:
commit
c9deb57fa4
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2025 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package jwt implements JWT token file-based call credentials.
|
||||
//
|
||||
// This package provides support for A97 JWT Call Credentials, allowing gRPC
|
||||
// clients to authenticate using JWT tokens read from files. While originally
|
||||
// designed for xDS environments, these credentials are general-purpose.
|
||||
//
|
||||
// The credentials can be used directly in gRPC clients or configured via xDS.
|
||||
//
|
||||
// # Token Requirements
|
||||
//
|
||||
// JWT tokens must:
|
||||
// - Be valid, well-formed JWT tokens with header, payload, and signature
|
||||
// - Include an "exp" (expiration) claim
|
||||
// - Be readable from the specified file path
|
||||
//
|
||||
// # Considerations
|
||||
//
|
||||
// - Tokens are cached until expiration to avoid excessive file I/O
|
||||
// - Transport security is required (RequireTransportSecurity returns true)
|
||||
// - Errors in reading tokens or parsing JWTs will result in RPC UNAVAILALBE or
|
||||
// UNAUTHENTICATED errors. The errors are cached and retried with exponential
|
||||
// backoff.
|
||||
//
|
||||
// This implementation is originally intended for use in service mesh
|
||||
// environments like Istio where JWT tokens are provisioned and rotated by the
|
||||
// infrastructure.
|
||||
//
|
||||
// # Experimental
|
||||
//
|
||||
// Notice: All APIs in this package are experimental and may be removed in a
|
||||
// later release.
|
||||
package jwt
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2025 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// jwtClaims represents the JWT claims structure for extracting expiration time.
|
||||
type jwtClaims struct {
|
||||
Exp int64 `json:"exp"`
|
||||
}
|
||||
|
||||
// jWTFileReader handles reading and parsing JWT tokens from files.
|
||||
type jWTFileReader struct {
|
||||
tokenFilePath string
|
||||
}
|
||||
|
||||
// newJWTFileReader creates a new JWTFileReader for the specified file path.
|
||||
func newJWTFileReader(tokenFilePath string) *jWTFileReader {
|
||||
return &jWTFileReader{
|
||||
tokenFilePath: tokenFilePath,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadToken reads and parses a JWT token from the configured file.
|
||||
// Returns the token string, expiration time, and any error encountered.
|
||||
func (r *jWTFileReader) ReadToken() (string, time.Time, error) {
|
||||
tokenBytes, err := os.ReadFile(r.tokenFilePath)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("failed to read token file %q: %v", r.tokenFilePath, err)
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(string(tokenBytes))
|
||||
if token == "" {
|
||||
return "", time.Time{}, fmt.Errorf("token file %q is empty", r.tokenFilePath)
|
||||
}
|
||||
|
||||
exp, err := r.extractExpiration(token)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("failed to parse JWT from token file %q: %v", r.tokenFilePath, err)
|
||||
}
|
||||
|
||||
return token, exp, nil
|
||||
}
|
||||
|
||||
// extractExpiration parses the JWT token to extract the expiration time.
|
||||
func (r *jWTFileReader) extractExpiration(token string) (time.Time, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
// Add padding if necessary for base64 decoding.
|
||||
if m := len(payload) % 4; m != 0 {
|
||||
payload += strings.Repeat("=", 4-m)
|
||||
}
|
||||
|
||||
payloadBytes, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err)
|
||||
}
|
||||
|
||||
var claims jwtClaims
|
||||
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
|
||||
}
|
||||
|
||||
if claims.Exp == 0 {
|
||||
return time.Time{}, fmt.Errorf("JWT token has no expiration claim")
|
||||
}
|
||||
|
||||
expTime := time.Unix(claims.Exp, 0)
|
||||
|
||||
// Check if token is already expired.
|
||||
if expTime.Before(time.Now()) {
|
||||
return time.Time{}, fmt.Errorf("JWT token is expired")
|
||||
}
|
||||
|
||||
return expTime, nil
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2025 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestJWTFileReader_ReadToken_FileErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFile func(string) error
|
||||
wantErrContains string
|
||||
}{
|
||||
{
|
||||
name: "nonexistent file",
|
||||
setupFile: func(_ string) error {
|
||||
return nil // Don't create the file
|
||||
},
|
||||
wantErrContains: "failed to read token file",
|
||||
},
|
||||
{
|
||||
name: "empty file",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte(""), 0600)
|
||||
},
|
||||
wantErrContains: "token file",
|
||||
},
|
||||
{
|
||||
name: "file with whitespace only",
|
||||
setupFile: func(path string) error {
|
||||
return os.WriteFile(path, []byte(" \n\t "), 0600)
|
||||
},
|
||||
wantErrContains: "token file",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
tokenFile := filepath.Join(tempDir, "token")
|
||||
if err := tt.setupFile(tokenFile); err != nil {
|
||||
t.Fatalf("Failed to setup test file: %v", err)
|
||||
}
|
||||
|
||||
reader := newJWTFileReader(tokenFile)
|
||||
_, _, err := reader.ReadToken()
|
||||
if err == nil {
|
||||
t.Fatal("ReadToken() expected error, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||
t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTFileReader_ReadToken_InvalidJWT(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenContent string
|
||||
wantErrContains string
|
||||
}{
|
||||
{
|
||||
name: "valid token without expiration",
|
||||
tokenContent: createTestJWT(t, "", time.Time{}),
|
||||
wantErrContains: "JWT token has no expiration claim",
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
tokenContent: createTestJWT(t, "", now.Add(-time.Hour)),
|
||||
wantErrContains: "JWT token is expired",
|
||||
},
|
||||
{
|
||||
name: "malformed JWT - not enough parts",
|
||||
tokenContent: "invalid.jwt",
|
||||
wantErrContains: "invalid JWT format: expected 3 parts, got 2",
|
||||
},
|
||||
{
|
||||
name: "malformed JWT - invalid base64",
|
||||
tokenContent: "header.invalid_base64!@#.signature",
|
||||
wantErrContains: "failed to decode JWT payload",
|
||||
},
|
||||
{
|
||||
name: "malformed JWT - invalid JSON",
|
||||
tokenContent: createInvalidJSONJWT(t),
|
||||
wantErrContains: "failed to unmarshal JWT claims",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenFile := writeTempFile(t, "token", tt.tokenContent)
|
||||
|
||||
reader := newJWTFileReader(tokenFile)
|
||||
_, _, err := reader.ReadToken()
|
||||
if err == nil {
|
||||
t.Fatal("ReadToken() expected error, got nil")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||
t.Fatalf("ReadToken() error = %v, want error containing %q", err, tt.wantErrContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTFileReader_ReadToken_ValidToken(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
tokenExp := now.Add(time.Hour)
|
||||
token := createTestJWT(t, "https://example.com", tokenExp)
|
||||
tokenFile := writeTempFile(t, "token", token)
|
||||
|
||||
reader := newJWTFileReader(tokenFile)
|
||||
readToken, expiry, err := reader.ReadToken()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadToken() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if readToken != token {
|
||||
t.Errorf("ReadToken() token = %q, want %q", readToken, token)
|
||||
}
|
||||
|
||||
if !expiry.Equal(tokenExp) {
|
||||
t.Errorf("ReadToken() expiry = %v, want %v", expiry, tokenExp)
|
||||
}
|
||||
}
|
||||
|
||||
// createInvalidJSONJWT creates a JWT with invalid JSON in the payload.
|
||||
func createInvalidJSONJWT(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
header := map[string]any{
|
||||
"typ": "JWT",
|
||||
"alg": "HS256",
|
||||
}
|
||||
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
|
||||
headerB64 = strings.TrimRight(headerB64, "=")
|
||||
|
||||
// Create invalid JSON payload
|
||||
invalidJSON := "invalid json content"
|
||||
payloadB64 := base64.URLEncoding.EncodeToString([]byte(invalidJSON))
|
||||
payloadB64 = strings.TrimRight(payloadB64, "=")
|
||||
|
||||
signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
|
||||
signature = strings.TrimRight(signature, "=")
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signature)
|
||||
}
|
|
@ -0,0 +1,181 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2025 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
// Package jwt implements gRPC credentials using JWT tokens from files.
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/internal/backoff"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const preemptiveRefreshThreshold = time.Minute
|
||||
|
||||
// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads
|
||||
// tokens from a file.
|
||||
// This implementation follows the A97 JWT Call Credentials specification.
|
||||
type jwtTokenFileCallCreds struct {
|
||||
fileReader *jWTFileReader
|
||||
backoffStrategy backoff.Strategy
|
||||
|
||||
// cached data protected by mu
|
||||
mu sync.Mutex
|
||||
cachedAuthHeader string // "Bearer " + token
|
||||
cachedExpiry time.Time // Slightly less than actual expiration time
|
||||
cachedError error // Error from last failed attempt
|
||||
retryAttempt int // Current retry attempt number
|
||||
nextRetryTime time.Time // When next retry is allowed
|
||||
pendingRefresh bool // Whether a refresh is currently in progress
|
||||
}
|
||||
|
||||
// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens
|
||||
// from the specified file path.
|
||||
func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) {
|
||||
if tokenFilePath == "" {
|
||||
return nil, fmt.Errorf("tokenFilePath cannot be empty")
|
||||
}
|
||||
|
||||
creds := &jwtTokenFileCallCreds{
|
||||
fileReader: newJWTFileReader(tokenFilePath),
|
||||
backoffStrategy: backoff.DefaultExponential,
|
||||
}
|
||||
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
// GetRequestMetadata gets the current request metadata, refreshing tokens if
|
||||
// required. This implementation follows the PerRPCCredentials interface. The
|
||||
// tokens will get automatically refreshed if they are about to expire or if
|
||||
// they haven't been loaded successfully yet.
|
||||
// If it's not possible to extract a token from the file, UNAVAILABLE is
|
||||
// returned.
|
||||
// If the token is extracted but invalid, then UNAUTHENTICATED is returned.
|
||||
// If errors are encoutered, a backoff is applied before retrying.
|
||||
func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
|
||||
ri, _ := credentials.RequestInfoFromContext(ctx)
|
||||
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
|
||||
return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.isTokenValidLocked() {
|
||||
if c.needsPreemptiveRefreshLocked() {
|
||||
// Start refresh if not pending (handling the prior RPC may have
|
||||
// just spawned a goroutine).
|
||||
if !c.pendingRefresh {
|
||||
c.pendingRefresh = true
|
||||
go c.refreshToken()
|
||||
}
|
||||
}
|
||||
return map[string]string{
|
||||
"authorization": c.cachedAuthHeader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// If in backoff state, just return the cached error.
|
||||
if c.cachedError != nil && time.Now().Before(c.nextRetryTime) {
|
||||
return nil, c.cachedError
|
||||
}
|
||||
|
||||
// At this point, the token is either invalid or expired and we are no
|
||||
// longer backing off. So refresh it.
|
||||
token, expiry, err := c.fileReader.ReadToken()
|
||||
c.updateCacheLocked(token, expiry, err)
|
||||
|
||||
if c.cachedError != nil {
|
||||
return nil, c.cachedError
|
||||
}
|
||||
return map[string]string{
|
||||
"authorization": c.cachedAuthHeader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RequireTransportSecurity indicates whether the credentials requires
|
||||
// transport security.
|
||||
func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// isTokenValidLocked checks if the cached token is still valid.
|
||||
// Caller must hold c.mu lock.
|
||||
func (c *jwtTokenFileCallCreds) isTokenValidLocked() bool {
|
||||
if c.cachedAuthHeader == "" {
|
||||
return false
|
||||
}
|
||||
return c.cachedExpiry.After(time.Now())
|
||||
}
|
||||
|
||||
// needsPreemptiveRefreshLocked checks if a pre-emptive refresh should be
|
||||
// triggered.
|
||||
// Returns true if the cached token is valid but expires within 1 minute.
|
||||
// We only trigger pre-emptive refresh for valid tokens - if the token is
|
||||
// invalid or expired, the next RPC will handle synchronous refresh instead.
|
||||
// Caller must hold c.mu lock.
|
||||
func (c *jwtTokenFileCallCreds) needsPreemptiveRefreshLocked() bool {
|
||||
return c.isTokenValidLocked() && time.Until(c.cachedExpiry) < preemptiveRefreshThreshold
|
||||
}
|
||||
|
||||
// refreshToken reads the token from file and updates the cached data.
|
||||
func (c *jwtTokenFileCallCreds) refreshToken() {
|
||||
// Deliberately not locking c.mu here
|
||||
token, expiry, err := c.fileReader.ReadToken()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.updateCacheLocked(token, expiry, err)
|
||||
|
||||
c.pendingRefresh = false
|
||||
}
|
||||
|
||||
// updateCacheLocked updates the cached token, expiry, and error state.
|
||||
// If an error is provided, it determines whether to set it as an UNAVAILABLE
|
||||
// or UNAUTHENTICATED error based on the error type.
|
||||
// Caller must hold c.mu lock.
|
||||
func (c *jwtTokenFileCallCreds) updateCacheLocked(token string, expiry time.Time, err error) {
|
||||
if err != nil {
|
||||
// Convert to gRPC status codes
|
||||
if strings.Contains(err.Error(), "failed to read token file") || strings.Contains(err.Error(), "token file") && strings.Contains(err.Error(), "is empty") {
|
||||
c.cachedError = status.Errorf(codes.Unavailable, "%v", err)
|
||||
} else {
|
||||
c.cachedError = status.Errorf(codes.Unauthenticated, "%v", err)
|
||||
}
|
||||
c.retryAttempt++
|
||||
backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1)
|
||||
c.nextRetryTime = time.Now().Add(backoffDelay)
|
||||
} else {
|
||||
// Success - clear any cached error and update token cache
|
||||
c.cachedError = nil
|
||||
c.retryAttempt = 0
|
||||
c.nextRetryTime = time.Time{}
|
||||
|
||||
c.cachedAuthHeader = "Bearer " + token
|
||||
// Per RFC A97: consider token invalid if it expires within the next 30
|
||||
// seconds to accommodate for clock skew and server processing time.
|
||||
c.cachedExpiry = expiry.Add(-30 * time.Second)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,564 @@
|
|||
/*
|
||||
*
|
||||
* Copyright 2025 gRPC authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/internal/grpctest"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const defaultTestTimeout = 5 * time.Second
|
||||
|
||||
type s struct {
|
||||
grpctest.Tester
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
grpctest.RunSubTests(t, s{})
|
||||
}
|
||||
|
||||
func (s) TestNewTokenFileCallCredentials(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenFilePath string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "some filepath",
|
||||
tokenFilePath: "/path/to/token",
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty filepath",
|
||||
tokenFilePath: "",
|
||||
wantErr: "tokenFilePath cannot be empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
creds, err := NewTokenFileCallCredentials(tt.tokenFilePath)
|
||||
if tt.wantErr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErr) {
|
||||
t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err)
|
||||
}
|
||||
if creds == nil {
|
||||
t.Fatal("NewTokenFileCallCredentials() returned nil credentials")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) {
|
||||
creds, err := NewTokenFileCallCredentials("/path/to/token")
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
if !creds.RequireTransportSecurity() {
|
||||
t.Error("RequireTransportSecurity() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "jwt_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
now := time.Now().Truncate(time.Second)
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenContent string
|
||||
authInfo credentials.AuthInfo
|
||||
wantErr bool
|
||||
wantErrContains string
|
||||
wantMetadata map[string]string
|
||||
}{
|
||||
{
|
||||
name: "valid token with future expiration",
|
||||
tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)),
|
||||
authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
||||
wantErr: false,
|
||||
wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))},
|
||||
},
|
||||
{
|
||||
name: "insufficient security level",
|
||||
tokenContent: createTestJWT(t, "", now.Add(time.Hour)),
|
||||
authInfo: &testAuthInfo{secLevel: credentials.NoSecurity},
|
||||
wantErr: true,
|
||||
wantErrContains: "unable to transfer JWT token file PerRPCCredentials",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenFile := writeTempFile(t, "token", tt.tokenContent)
|
||||
|
||||
creds, err := NewTokenFileCallCredentials(tokenFile)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
||||
AuthInfo: tt.authInfo,
|
||||
})
|
||||
|
||||
metadata, err := creds.GetRequestMetadata(ctx)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("GetRequestMetadata() expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||
t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("GetRequestMetadata() unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(metadata) != len(tt.wantMetadata) {
|
||||
t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata))
|
||||
}
|
||||
|
||||
for k, v := range tt.wantMetadata {
|
||||
if metadata[k] != v {
|
||||
t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) {
|
||||
token := createTestJWT(t, "", time.Now().Add(time.Hour))
|
||||
tokenFile := writeTempFile(t, "token", token)
|
||||
|
||||
creds, err := NewTokenFileCallCredentials(tokenFile)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
||||
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
||||
})
|
||||
|
||||
// First call should read from file.
|
||||
metadata1, err := creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("First GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
// Update the file with a different token.
|
||||
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
|
||||
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
|
||||
t.Fatalf("Failed to update token file: %v", err)
|
||||
}
|
||||
|
||||
// Second call should return cached token (not the updated one).
|
||||
metadata2, err := creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
if metadata1["authorization"] != metadata2["authorization"] {
|
||||
t.Error("Expected cached token to be returned, but got different token")
|
||||
}
|
||||
}
|
||||
|
||||
// testAuthInfo implements credentials.AuthInfo for testing.
|
||||
type testAuthInfo struct {
|
||||
secLevel credentials.SecurityLevel
|
||||
}
|
||||
|
||||
func (t *testAuthInfo) AuthType() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo {
|
||||
return credentials.CommonAuthInfo{SecurityLevel: t.secLevel}
|
||||
}
|
||||
|
||||
// Tests that cached token expiration is set to 30 seconds before actual token
|
||||
// expiration.
|
||||
func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) {
|
||||
// Create token that expires in 2 hours.
|
||||
tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour)
|
||||
token := createTestJWT(t, "", tokenExp)
|
||||
tokenFile := writeTempFile(t, "token", token)
|
||||
|
||||
creds, err := NewTokenFileCallCredentials(tokenFile)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
||||
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
||||
})
|
||||
|
||||
// Get token to trigger caching.
|
||||
_, err = creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify cached expiration is 30 seconds before actual token expiration.
|
||||
impl := creds.(*jwtTokenFileCallCreds)
|
||||
impl.mu.Lock()
|
||||
cachedExp := impl.cachedExpiry
|
||||
impl.mu.Unlock()
|
||||
|
||||
expectedExp := tokenExp.Add(-30 * time.Second)
|
||||
if !cachedExp.Equal(expectedExp) {
|
||||
t.Errorf("cache expiration = %v, want %v", cachedExp, expectedExp)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that pre-emptive refresh is triggered within 1 minute of expiration.
|
||||
func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) {
|
||||
// Create token that expires in 80 seconds (=> cache expires in ~50s).
|
||||
// This ensures pre-emptive refresh triggers since 50s < the 1 minute check.
|
||||
tokenExp := time.Now().Add(80 * time.Second)
|
||||
expiringToken := createTestJWT(t, "", tokenExp)
|
||||
tokenFile := writeTempFile(t, "token", expiringToken)
|
||||
|
||||
creds, err := NewTokenFileCallCredentials(tokenFile)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
||||
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
||||
})
|
||||
|
||||
// Get token - should trigger pre-emptive refresh.
|
||||
metadata1, err := creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify token was cached and check if refresh should be triggered.
|
||||
impl := creds.(*jwtTokenFileCallCreds)
|
||||
impl.mu.Lock()
|
||||
cacheExp := impl.cachedExpiry
|
||||
tokenCached := impl.cachedAuthHeader != ""
|
||||
shouldTriggerRefresh := impl.needsPreemptiveRefreshLocked()
|
||||
impl.mu.Unlock()
|
||||
|
||||
if !tokenCached {
|
||||
t.Error("token should be cached after successful GetRequestMetadata")
|
||||
}
|
||||
|
||||
if !shouldTriggerRefresh {
|
||||
timeUntilExp := time.Until(cacheExp)
|
||||
t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp)
|
||||
}
|
||||
|
||||
// Create new token file with different expiration while refresh is
|
||||
// happening.
|
||||
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
|
||||
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
|
||||
t.Fatalf("Failed to write updated token file: %v", err)
|
||||
}
|
||||
|
||||
// Get token again - should trigger a refresh given that the first one was
|
||||
// cached but expiring soon.
|
||||
// However, the function should have returned right away with the current
|
||||
// cached token.
|
||||
metadata2, err := creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Now should get the new token.
|
||||
metadata3, err := creds.GetRequestMetadata(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
||||
}
|
||||
|
||||
// If pre-emptive refresh worked, we should get the new token.
|
||||
expectedAuth1 := "Bearer " + expiringToken
|
||||
expectedAuth2 := "Bearer " + expiringToken
|
||||
expectedAuth3 := "Bearer " + newToken
|
||||
|
||||
actualAuth1 := metadata1["authorization"]
|
||||
actualAuth2 := metadata2["authorization"]
|
||||
actualAuth3 := metadata3["authorization"]
|
||||
|
||||
if actualAuth1 != expectedAuth1 {
|
||||
t.Errorf("First call should return original token: got %q, want %q", actualAuth1, expectedAuth1)
|
||||
}
|
||||
|
||||
if actualAuth2 != expectedAuth2 {
|
||||
t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2)
|
||||
}
|
||||
if actualAuth3 != expectedAuth3 {
|
||||
t.Errorf("Third call should return the new token: got %q, want %q", actualAuth3, expectedAuth3)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that backoff behavior handles file read errors correctly.
|
||||
func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) {
|
||||
// This test has the following expectations:
|
||||
// First call to GetRequestMetadata() fails with UNAVAILABLE due to a
|
||||
// missing file.
|
||||
// Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff.
|
||||
// Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry.
|
||||
// Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff
|
||||
// even though file exists.
|
||||
// Fifth call to GetRequestMetadata() succeeds after reading the file and
|
||||
// backoff has expired.
|
||||
tempDir := t.TempDir()
|
||||
nonExistentFile := filepath.Join(tempDir, "nonexistent")
|
||||
|
||||
creds, err := NewTokenFileCallCredentials(nonExistentFile)
|
||||
if err != nil {
|
||||
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
|
||||
defer cancel()
|
||||
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
||||
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
||||
})
|
||||
|
||||
// First call should fail with UNAVAILABLE.
|
||||
_, err1 := creds.GetRequestMetadata(ctx)
|
||||
if err1 == nil {
|
||||
t.Fatal("Expected error from nonexistent file")
|
||||
}
|
||||
if status.Code(err1) != codes.Unavailable {
|
||||
t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1))
|
||||
}
|
||||
|
||||
// Verify error is cached internally.
|
||||
impl := creds.(*jwtTokenFileCallCreds)
|
||||
impl.mu.Lock()
|
||||
cachedErr := impl.cachedError
|
||||
retryAttempt := impl.retryAttempt
|
||||
nextRetryTime := impl.nextRetryTime
|
||||
impl.mu.Unlock()
|
||||
|
||||
if cachedErr == nil {
|
||||
t.Error("error should be cached internally after failed file read")
|
||||
}
|
||||
if retryAttempt != 1 {
|
||||
t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt)
|
||||
}
|
||||
if nextRetryTime.IsZero() || nextRetryTime.Before(time.Now()) {
|
||||
t.Error("Next retry time should be set to future time")
|
||||
}
|
||||
|
||||
// Second call should still return cached error.
|
||||
_, err2 := creds.GetRequestMetadata(ctx)
|
||||
if err2 == nil {
|
||||
t.Fatal("Expected cached error")
|
||||
}
|
||||
if status.Code(err2) != codes.Unavailable {
|
||||
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err2))
|
||||
}
|
||||
if err1.Error() != err2.Error() {
|
||||
t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error())
|
||||
}
|
||||
|
||||
impl.mu.Lock()
|
||||
retryAttempt2 := impl.retryAttempt
|
||||
nextRetryTime2 := impl.nextRetryTime
|
||||
impl.mu.Unlock()
|
||||
|
||||
if !nextRetryTime2.Equal(nextRetryTime) {
|
||||
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime)
|
||||
}
|
||||
if retryAttempt2 != 1 {
|
||||
t.Error("retry attempt should not change due to backoff")
|
||||
}
|
||||
|
||||
// Fast-forward the backoff retry time to allow next retry attempt.
|
||||
impl.mu.Lock()
|
||||
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
|
||||
impl.mu.Unlock()
|
||||
|
||||
// Third call should retry but still fail with UNAVAILABLE.
|
||||
_, err3 := creds.GetRequestMetadata(ctx)
|
||||
if err3 == nil {
|
||||
t.Fatal("Expected cached error")
|
||||
}
|
||||
if status.Code(err3) != codes.Unavailable {
|
||||
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err3))
|
||||
}
|
||||
if err3.Error() != err1.Error() {
|
||||
t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error())
|
||||
}
|
||||
|
||||
impl.mu.Lock()
|
||||
retryAttempt3 := impl.retryAttempt
|
||||
nextRetryTime3 := impl.nextRetryTime
|
||||
impl.mu.Unlock()
|
||||
|
||||
if !nextRetryTime3.After(nextRetryTime2) {
|
||||
t.Error("nextRetryTime should not change due to backoff")
|
||||
}
|
||||
if retryAttempt3 != 2 {
|
||||
t.Error("retry attempt should not change due to backoff")
|
||||
}
|
||||
|
||||
// Create valid token file.
|
||||
validToken := createTestJWT(t, "", time.Now().Add(time.Hour))
|
||||
if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil {
|
||||
t.Fatalf("Failed to create valid token file: %v", err)
|
||||
}
|
||||
|
||||
// Fourth call should still fail even though the file now exists.
|
||||
_, err4 := creds.GetRequestMetadata(ctx)
|
||||
if err4 == nil {
|
||||
t.Fatal("Expected cached error")
|
||||
}
|
||||
if status.Code(err4) != codes.Unavailable {
|
||||
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err4))
|
||||
}
|
||||
if err4.Error() != err3.Error() {
|
||||
t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error())
|
||||
}
|
||||
|
||||
impl.mu.Lock()
|
||||
retryAttempt4 := impl.retryAttempt
|
||||
nextRetryTime4 := impl.nextRetryTime
|
||||
impl.mu.Unlock()
|
||||
|
||||
if !nextRetryTime4.Equal(nextRetryTime3) {
|
||||
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3)
|
||||
}
|
||||
if retryAttempt4 != retryAttempt3 {
|
||||
t.Error("retry attempt should not change due to backoff")
|
||||
}
|
||||
|
||||
// Fast-forward the backoff retry time to allow next retry attempt.
|
||||
impl.mu.Lock()
|
||||
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
|
||||
impl.mu.Unlock()
|
||||
// Fifth call should succeed since the file now exists
|
||||
// and the backoff has expired.
|
||||
_, err5 := creds.GetRequestMetadata(ctx)
|
||||
if err5 != nil {
|
||||
t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5)
|
||||
t.Error("backoff should expire and trigger new attempt on next RPC")
|
||||
} else {
|
||||
// If successful, verify error cache and backoff state were cleared.
|
||||
impl.mu.Lock()
|
||||
clearedErr := impl.cachedError
|
||||
retryAttempt := impl.retryAttempt
|
||||
nextRetryTime := impl.nextRetryTime
|
||||
impl.mu.Unlock()
|
||||
|
||||
if clearedErr != nil {
|
||||
t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr)
|
||||
}
|
||||
if retryAttempt != 0 {
|
||||
t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt)
|
||||
}
|
||||
if !nextRetryTime.IsZero() {
|
||||
t.Error("after successful retry, next retry time should be cleared")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// createTestJWT creates a test JWT token with the specified audience and
|
||||
// expiration.
|
||||
func createTestJWT(t *testing.T, audience string, expiration time.Time) string {
|
||||
t.Helper()
|
||||
|
||||
header := map[string]any{
|
||||
"typ": "JWT",
|
||||
"alg": "HS256",
|
||||
}
|
||||
|
||||
claims := map[string]any{}
|
||||
if audience != "" {
|
||||
claims["aud"] = audience
|
||||
}
|
||||
if !expiration.IsZero() {
|
||||
claims["exp"] = expiration.Unix()
|
||||
}
|
||||
|
||||
headerBytes, err := json.Marshal(header)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal header: %v", err)
|
||||
}
|
||||
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal claims: %v", err)
|
||||
}
|
||||
|
||||
headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
|
||||
claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes)
|
||||
|
||||
// Remove padding for URL-safe base64
|
||||
headerB64 = strings.TrimRight(headerB64, "=")
|
||||
claimsB64 = strings.TrimRight(claimsB64, "=")
|
||||
|
||||
// For testing, we'll use a fake signature
|
||||
signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
|
||||
signature = strings.TrimRight(signature, "=")
|
||||
|
||||
return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature)
|
||||
}
|
||||
|
||||
func writeTempFile(t *testing.T, name, content string) string {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
filePath := filepath.Join(tempDir, name)
|
||||
if err := os.WriteFile(filePath, []byte(content), 0600); err != nil {
|
||||
t.Fatalf("Failed to write temp file: %v", err)
|
||||
}
|
||||
return filePath
|
||||
}
|
Loading…
Reference in New Issue