mirror of https://github.com/grpc/grpc-go.git
763 lines
23 KiB
Go
763 lines
23 KiB
Go
/*
|
|
*
|
|
* Copyright 2025 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/internal/grpctest"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
type s struct {
|
|
grpctest.Tester
|
|
}
|
|
|
|
func Test(t *testing.T) {
|
|
grpctest.RunSubTests(t, s{})
|
|
}
|
|
|
|
func (s) TestNewTokenFileCallCredentials(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tokenFilePath string
|
|
wantErr bool
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "valid parameters",
|
|
tokenFilePath: "/path/to/token",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty token file path",
|
|
tokenFilePath: "",
|
|
wantErr: true,
|
|
wantErrContains: "tokenFilePath cannot be empty",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
creds, err := NewTokenFileCallCredentials(tt.tokenFilePath)
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() expected error, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
|
t.Fatalf("NewTokenFileCallCredentials() error = %v, want error containing %q", err, tt.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() unexpected error: %v", err)
|
|
}
|
|
if creds == nil {
|
|
t.Fatal("NewTokenFileCallCredentials() returned nil credentials")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestTokenFileCallCreds_RequireTransportSecurity(t *testing.T) {
|
|
creds, err := NewTokenFileCallCredentials("/path/to/token")
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
if !creds.RequireTransportSecurity() {
|
|
t.Error("RequireTransportSecurity() = false, want true")
|
|
}
|
|
}
|
|
|
|
func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "jwt_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
now := time.Now().Truncate(time.Second)
|
|
tests := []struct {
|
|
name string
|
|
tokenContent string
|
|
authInfo credentials.AuthInfo
|
|
wantErr bool
|
|
wantErrContains string
|
|
wantMetadata map[string]string
|
|
}{
|
|
{
|
|
name: "valid token without expiration errors",
|
|
tokenContent: createTestJWT(t, "", time.Time{}),
|
|
authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
wantErr: true,
|
|
wantErrContains: "JWT token has no expiration claim",
|
|
},
|
|
{
|
|
name: "valid token with future expiration succeeds",
|
|
tokenContent: createTestJWT(t, "https://example.com", now.Add(time.Hour)),
|
|
authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
wantErr: false,
|
|
wantMetadata: map[string]string{"authorization": "Bearer " + createTestJWT(t, "https://example.com", now.Add(time.Hour))},
|
|
},
|
|
{
|
|
name: "insufficient security level",
|
|
tokenContent: createTestJWT(t, "", time.Time{}),
|
|
authInfo: &testAuthInfo{secLevel: credentials.NoSecurity},
|
|
wantErr: true,
|
|
wantErrContains: "unable to transfer JWT token file PerRPCCredentials",
|
|
},
|
|
{
|
|
name: "expired token errors",
|
|
tokenContent: createTestJWT(t, "", now.Add(-time.Hour)),
|
|
authInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
wantErr: true,
|
|
wantErrContains: "JWT token is expired",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tokenFile := writeTempFile(t, "token", tt.tokenContent)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: tt.authInfo,
|
|
})
|
|
|
|
metadata, err := creds.GetRequestMetadata(ctx)
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatalf("GetRequestMetadata() expected error, got nil")
|
|
}
|
|
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
|
t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("GetRequestMetadata() unexpected error: %v", err)
|
|
}
|
|
|
|
if len(metadata) != len(tt.wantMetadata) {
|
|
t.Fatalf("GetRequestMetadata() returned %d metadata entries, want %d", len(metadata), len(tt.wantMetadata))
|
|
}
|
|
|
|
for k, v := range tt.wantMetadata {
|
|
if metadata[k] != v {
|
|
t.Errorf("GetRequestMetadata() metadata[%q] = %q, want %q", k, metadata[k], v)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) {
|
|
|
|
token := createTestJWT(t, "", time.Now().Add(time.Hour))
|
|
tokenFile := writeTempFile(t, "token", token)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
// First call should read from file
|
|
metadata1, err := creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("First GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
// Update the file with a different token
|
|
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
|
|
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
|
|
t.Fatalf("Failed to update token file: %v", err)
|
|
}
|
|
|
|
// Second call should return cached token (not the updated one)
|
|
metadata2, err := creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
if metadata1["authorization"] != metadata2["authorization"] {
|
|
t.Error("Expected cached token to be returned, but got different token")
|
|
}
|
|
}
|
|
|
|
func (s) TestTokenFileCallCreds_FileErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupFile func(string) error
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "nonexistent file",
|
|
setupFile: func(_ string) error {
|
|
return nil // Don't create the file
|
|
},
|
|
wantErrContains: "failed to read token file",
|
|
},
|
|
{
|
|
name: "empty file",
|
|
setupFile: func(path string) error {
|
|
return os.WriteFile(path, []byte(""), 0600)
|
|
},
|
|
wantErrContains: "token file",
|
|
},
|
|
{
|
|
name: "file with whitespace only",
|
|
setupFile: func(path string) error {
|
|
return os.WriteFile(path, []byte(" \n\t "), 0600)
|
|
},
|
|
wantErrContains: "token file",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tempDir, err := os.MkdirTemp("", "jwt_test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
tokenFile := filepath.Join(tempDir, "token")
|
|
if err := tt.setupFile(tokenFile); err != nil {
|
|
t.Fatalf("Failed to setup test file: %v", err)
|
|
}
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
_, err = creds.GetRequestMetadata(ctx)
|
|
if err == nil {
|
|
t.Fatal("GetRequestMetadata() expected error, got nil")
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), tt.wantErrContains) {
|
|
t.Fatalf("GetRequestMetadata() error = %v, want error containing %q", err, tt.wantErrContains)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testAuthInfo implements credentials.AuthInfo for testing.
|
|
type testAuthInfo struct {
|
|
secLevel credentials.SecurityLevel
|
|
}
|
|
|
|
func (t *testAuthInfo) AuthType() string {
|
|
return "test"
|
|
}
|
|
|
|
func (t *testAuthInfo) GetCommonAuthInfo() credentials.CommonAuthInfo {
|
|
return credentials.CommonAuthInfo{SecurityLevel: t.secLevel}
|
|
}
|
|
|
|
// createTestJWT creates a test JWT token with the specified audience and expiration.
|
|
func createTestJWT(t *testing.T, audience string, expiration time.Time) string {
|
|
t.Helper()
|
|
|
|
header := map[string]any{
|
|
"typ": "JWT",
|
|
"alg": "HS256",
|
|
}
|
|
|
|
claims := map[string]any{}
|
|
if audience != "" {
|
|
claims["aud"] = audience
|
|
}
|
|
if !expiration.IsZero() {
|
|
claims["exp"] = expiration.Unix()
|
|
}
|
|
|
|
headerBytes, err := json.Marshal(header)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal header: %v", err)
|
|
}
|
|
|
|
claimsBytes, err := json.Marshal(claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to marshal claims: %v", err)
|
|
}
|
|
|
|
headerB64 := base64.URLEncoding.EncodeToString(headerBytes)
|
|
claimsB64 := base64.URLEncoding.EncodeToString(claimsBytes)
|
|
|
|
// Remove padding for URL-safe base64
|
|
headerB64 = strings.TrimRight(headerB64, "=")
|
|
claimsB64 = strings.TrimRight(claimsB64, "=")
|
|
|
|
// For testing, we'll use a fake signature
|
|
signature := base64.URLEncoding.EncodeToString([]byte("fake_signature"))
|
|
signature = strings.TrimRight(signature, "=")
|
|
|
|
return fmt.Sprintf("%s.%s.%s", headerB64, claimsB64, signature)
|
|
}
|
|
|
|
// Tests that cached token expiration is set to 30 seconds before actual token expiration.
|
|
func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) {
|
|
// Create token that expires in 2 hours
|
|
tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour)
|
|
token := createTestJWT(t, "", tokenExp)
|
|
tokenFile := writeTempFile(t, "token", token)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
// Get token to trigger caching
|
|
_, err = creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
// Verify cached expiration is 30 seconds before actual token expiration
|
|
impl := creds.(*jwtTokenFileCallCreds)
|
|
impl.mu.RLock()
|
|
cachedExp := impl.cachedExpiration
|
|
impl.mu.RUnlock()
|
|
|
|
expectedExp := tokenExp.Add(-30 * time.Second)
|
|
if !cachedExp.Equal(expectedExp) {
|
|
t.Errorf("cache expiration = %v, want %v", cachedExp, expectedExp)
|
|
}
|
|
}
|
|
|
|
// Tests that pre-emptive refresh is triggered within 1 minute of expiration.
|
|
func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) {
|
|
// Create token that expires in 80 seconds (=> cache expires in ~50s)
|
|
// This ensures pre-emptive refresh triggers since 50s < the 1 minute check
|
|
tokenExp := time.Now().Add(80 * time.Second)
|
|
expiringToken := createTestJWT(t, "", tokenExp)
|
|
tokenFile := writeTempFile(t, "token", expiringToken)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
// Get token - should trigger pre-emptive refresh
|
|
metadata1, err := creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
// Verify token was cached and check if refresh should be triggered
|
|
impl := creds.(*jwtTokenFileCallCreds)
|
|
impl.mu.RLock()
|
|
cacheExp := impl.cachedExpiration
|
|
tokenCached := impl.cachedToken != ""
|
|
shouldTriggerRefresh := impl.needsPreemptiveRefresh()
|
|
impl.mu.RUnlock()
|
|
|
|
if !tokenCached {
|
|
t.Error("token should be cached after successful GetRequestMetadata")
|
|
}
|
|
|
|
if !shouldTriggerRefresh {
|
|
timeUntilExp := time.Until(cacheExp)
|
|
t.Errorf("cache expires in %v, should be < 1 minute to trigger pre-emptive refresh", timeUntilExp)
|
|
}
|
|
|
|
// Create new token file with different expiration while refresh is happening
|
|
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
|
|
if err := os.WriteFile(tokenFile, []byte(newToken), 0600); err != nil {
|
|
t.Fatalf("Failed to write updated token file: %v", err)
|
|
}
|
|
|
|
// Get token again - should trigger a refresh given that the first one was
|
|
// cached but expiring soon
|
|
// However, the function should have returned right away with the current cached token
|
|
metadata2, err := creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// now should get the new token
|
|
metadata3, err := creds.GetRequestMetadata(ctx)
|
|
if err != nil {
|
|
t.Fatalf("Second GetRequestMetadata() failed: %v", err)
|
|
}
|
|
|
|
// If pre-emptive refresh worked, we should get the new token
|
|
expectedAuth1 := "Bearer " + expiringToken
|
|
expectedAuth2 := "Bearer " + expiringToken
|
|
expectedAuth3 := "Bearer " + newToken
|
|
|
|
actualAuth1 := metadata1["authorization"]
|
|
actualAuth2 := metadata2["authorization"]
|
|
actualAuth3 := metadata3["authorization"]
|
|
|
|
if actualAuth1 != expectedAuth1 {
|
|
t.Errorf("First call should return original token: got %q, want %q", actualAuth1, expectedAuth1)
|
|
}
|
|
|
|
if actualAuth2 != expectedAuth2 {
|
|
t.Errorf("Second call should return the original token: got %q, want %q", actualAuth2, expectedAuth2)
|
|
}
|
|
if actualAuth3 != expectedAuth3 {
|
|
t.Errorf("Third call should return the original token: got %q, want %q", actualAuth3, expectedAuth3)
|
|
}
|
|
}
|
|
|
|
// Tests that backoff behavior handles file read errors correctly.
|
|
func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) {
|
|
// This test has the following flow:
|
|
// First call to GetRequestMetadata() fails with UNAVAILABLE due to a missing file.
|
|
// Second call to GetRequestMetadata() fails with UNAVAILABLE due backoff.
|
|
// Third call to GetRequestMetadata() fails with UNAVAILABLE due to retry.
|
|
// Fourth call to GetRequestMetadata() fails with UNAVAILABLE due to backoff even though file exists.
|
|
// Fifth call to GetRequestMetadata() succeeds after creating the file.
|
|
tempDir := t.TempDir()
|
|
nonExistentFile := filepath.Join(tempDir, "nonexistent")
|
|
|
|
creds, err := NewTokenFileCallCredentials(nonExistentFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
// First call should fail with UNAVAILABLE
|
|
_, err1 := creds.GetRequestMetadata(ctx)
|
|
if err1 == nil {
|
|
t.Fatal("Expected error from nonexistent file")
|
|
}
|
|
if status.Code(err1) != codes.Unavailable {
|
|
t.Fatalf("GetRequestMetadata() = %v, want UNAVAILABLE", status.Code(err1))
|
|
}
|
|
|
|
// Verify error is cached internally
|
|
impl := creds.(*jwtTokenFileCallCreds)
|
|
impl.mu.RLock()
|
|
cachedErr := impl.cachedError
|
|
cachedErrTime := impl.cachedErrorTime
|
|
retryAttempt := impl.retryAttempt
|
|
nextRetryTime := impl.nextRetryTime
|
|
impl.mu.RUnlock()
|
|
|
|
if cachedErr == nil {
|
|
t.Error("error should be cached internally after failed file read")
|
|
}
|
|
if cachedErrTime.IsZero() {
|
|
t.Error("error cache time should be set")
|
|
}
|
|
if retryAttempt != 1 {
|
|
t.Errorf("Expected retry attempt to be 1, got %d", retryAttempt)
|
|
}
|
|
if nextRetryTime.IsZero() || nextRetryTime.Before(time.Now()) {
|
|
t.Error("Next retry time should be set to future time")
|
|
}
|
|
|
|
// Second call should still return cached error
|
|
_, err2 := creds.GetRequestMetadata(ctx)
|
|
if err2 == nil {
|
|
t.Fatal("Expected cached error")
|
|
}
|
|
if status.Code(err2) != codes.Unavailable {
|
|
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err2))
|
|
}
|
|
if err1.Error() != err2.Error() {
|
|
t.Errorf("cached error = %q, want %q", err2.Error(), err1.Error())
|
|
}
|
|
|
|
impl.mu.RLock()
|
|
retryAttempt2 := impl.retryAttempt
|
|
nextRetryTime2 := impl.nextRetryTime
|
|
impl.mu.RUnlock()
|
|
|
|
if !nextRetryTime2.Equal(nextRetryTime) {
|
|
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime2, nextRetryTime)
|
|
}
|
|
if retryAttempt2 != 1 {
|
|
t.Error("retry attempt should not change due to backoff")
|
|
}
|
|
|
|
// fast-forward the backoff retry time to allow next retry attempt
|
|
impl.mu.Lock()
|
|
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
|
|
impl.mu.Unlock()
|
|
|
|
// Third call should retry but still fail with UNAVAILABLE
|
|
_, err3 := creds.GetRequestMetadata(ctx)
|
|
if err3 == nil {
|
|
t.Fatal("Expected cached error")
|
|
}
|
|
if status.Code(err3) != codes.Unavailable {
|
|
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err3))
|
|
}
|
|
if err3.Error() != err1.Error() {
|
|
t.Errorf("cached error = %q, want %q", err3.Error(), err1.Error())
|
|
}
|
|
|
|
impl.mu.RLock()
|
|
retryAttempt3 := impl.retryAttempt
|
|
nextRetryTime3 := impl.nextRetryTime
|
|
impl.mu.RUnlock()
|
|
|
|
if !nextRetryTime3.After(nextRetryTime2) {
|
|
t.Error("nextRetryTime should not change due to backoff")
|
|
}
|
|
if retryAttempt3 != 2 {
|
|
t.Error("retry attempt should not change due to backoff")
|
|
}
|
|
|
|
// Create valid token file
|
|
validToken := createTestJWT(t, "", time.Now().Add(time.Hour))
|
|
if err := os.WriteFile(nonExistentFile, []byte(validToken), 0600); err != nil {
|
|
t.Fatalf("Failed to create valid token file: %v", err)
|
|
}
|
|
|
|
// Forth call should still fail even though the file now exists
|
|
_, err4 := creds.GetRequestMetadata(ctx)
|
|
if err4 == nil {
|
|
t.Fatal("Expected cached error")
|
|
}
|
|
if status.Code(err4) != codes.Unavailable {
|
|
t.Fatalf("GetRequestMetadata() = %v, want cached UNAVAILABLE", status.Code(err4))
|
|
}
|
|
if err4.Error() != err3.Error() {
|
|
t.Errorf("cached error = %q, want %q", err4.Error(), err3.Error())
|
|
}
|
|
|
|
impl.mu.RLock()
|
|
retryAttempt4 := impl.retryAttempt
|
|
nextRetryTime4 := impl.nextRetryTime
|
|
impl.mu.RUnlock()
|
|
|
|
if !nextRetryTime4.Equal(nextRetryTime3) {
|
|
t.Errorf("nextRetryTime should not change due to backoff. Got: %v, Want: %v", nextRetryTime4, nextRetryTime3)
|
|
}
|
|
if retryAttempt4 != retryAttempt3 {
|
|
t.Error("retry attempt should not change due to backoff")
|
|
}
|
|
|
|
// fast-forward the backoff retry time to allow next retry attempt
|
|
impl.mu.Lock()
|
|
impl.nextRetryTime = time.Now().Add(-1 * time.Minute)
|
|
impl.mu.Unlock()
|
|
// Fifth call should succeed since the file now exists
|
|
// and the backoff has expired
|
|
_, err5 := creds.GetRequestMetadata(ctx)
|
|
if err5 != nil {
|
|
t.Errorf("after creating valid token file, GetRequestMetadata() should eventually succeed, but got: %v", err5)
|
|
t.Error("backoff should expire and trigger new attempt on next RPC")
|
|
} else {
|
|
// If successful, verify error cache and backoff state were cleared
|
|
impl.mu.RLock()
|
|
clearedErr := impl.cachedError
|
|
clearedErrTime := impl.cachedErrorTime
|
|
retryAttempt := impl.retryAttempt
|
|
nextRetryTime := impl.nextRetryTime
|
|
impl.mu.RUnlock()
|
|
|
|
if clearedErr != nil {
|
|
t.Errorf("after successful retry, cached error should be cleared, got: %v", clearedErr)
|
|
}
|
|
if !clearedErrTime.IsZero() {
|
|
t.Error("after successful retry, cached error time should be cleared")
|
|
}
|
|
if retryAttempt != 0 {
|
|
t.Errorf("after successful retry, retry attempt should be reset, got: %d", retryAttempt)
|
|
}
|
|
if !nextRetryTime.IsZero() {
|
|
t.Error("after successful retry, next retry time should be cleared")
|
|
}
|
|
}
|
|
}
|
|
|
|
// Tests that invalid JWT tokens are handled with UNAUTHENTICATED status.
|
|
func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) {
|
|
// Write invalid JWT (missing exp field)
|
|
invalidJWT := createTestJWT(t, "", time.Time{}) // No expiration
|
|
tokenFile := writeTempFile(t, "token", invalidJWT)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
_, err = creds.GetRequestMetadata(ctx)
|
|
if err == nil {
|
|
t.Fatal("Expected UNAUTHENTICATED from invalid JWT")
|
|
}
|
|
if status.Code(err) != codes.Unauthenticated {
|
|
t.Errorf("GetRequestMetadata() = %v, want UNAUTHENTICATED for invalid JWT", status.Code(err))
|
|
}
|
|
}
|
|
|
|
// Tests that RPCs are queued during file operations and all receive the same result.
|
|
func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) {
|
|
tempDir := t.TempDir()
|
|
tokenFile := filepath.Join(tempDir, "token")
|
|
|
|
// Start with no token file to force file read during first RPC
|
|
creds, err := NewTokenFileCallCredentials(tokenFile)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
ctx = credentials.NewContextWithRequestInfo(ctx, credentials.RequestInfo{
|
|
AuthInfo: &testAuthInfo{secLevel: credentials.PrivacyAndIntegrity},
|
|
})
|
|
|
|
// Launch multiple concurrent RPCs before creating the token file
|
|
const numConcurrentRPCs = 5
|
|
results := make(chan error, numConcurrentRPCs)
|
|
|
|
for range numConcurrentRPCs {
|
|
go func() {
|
|
_, err := creds.GetRequestMetadata(ctx)
|
|
results <- err
|
|
}()
|
|
}
|
|
|
|
// Collect all results - they should all be the same error (UNAVAILABLE)
|
|
var errors []error
|
|
for range numConcurrentRPCs {
|
|
err := <-results
|
|
errors = append(errors, err)
|
|
}
|
|
|
|
// All RPCs should fail with the same error (file not found)
|
|
for i, err := range errors {
|
|
if err == nil {
|
|
t.Errorf("RPC %d should have failed with UNAVAILABLE", i)
|
|
continue
|
|
}
|
|
if status.Code(err) != codes.Unavailable {
|
|
t.Errorf("RPC %d = %v, want UNAVAILABLE", i, status.Code(err))
|
|
}
|
|
if i > 0 && err.Error() != errors[0].Error() {
|
|
t.Errorf("RPC %d error should match first RPC error for proper queueing", i)
|
|
}
|
|
}
|
|
|
|
// Verify error was cached after concurrent RPCs
|
|
impl := creds.(*jwtTokenFileCallCreds)
|
|
impl.mu.RLock()
|
|
finalCachedErr := impl.cachedError
|
|
impl.mu.RUnlock()
|
|
|
|
if finalCachedErr == nil {
|
|
t.Error("error should be cached after failed concurrent RPCs")
|
|
}
|
|
if finalCachedErr.Error() != errors[0].Error() {
|
|
t.Error("cached error should match the errors returned to RPCs")
|
|
}
|
|
}
|
|
|
|
// Tests that no background retries occur when channel is idle.
|
|
func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) {
|
|
newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour))
|
|
tokenFilepath := writeTempFile(t, "token", newToken)
|
|
|
|
creds, err := NewTokenFileCallCredentials(tokenFilepath)
|
|
if err != nil {
|
|
t.Fatalf("NewTokenFileCallCredentials() failed: %v", err)
|
|
}
|
|
|
|
impl := creds.(*jwtTokenFileCallCreds)
|
|
|
|
// Verify state unchanged - no background file reads attempted
|
|
impl.mu.RLock()
|
|
token := impl.cachedToken
|
|
cachedErr := impl.cachedError
|
|
impl.mu.RUnlock()
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
if token != "" {
|
|
t.Errorf("after idle period, cached token = %q, want empty (no background reads)", token)
|
|
}
|
|
if cachedErr != nil {
|
|
t.Errorf("after idle period, cached error = %v, want nil (no background reads)", cachedErr)
|
|
}
|
|
}
|
|
|
|
func writeTempFile(t *testing.T, name, content string) string {
|
|
t.Helper()
|
|
tempDir := t.TempDir()
|
|
filePath := filepath.Join(tempDir, name)
|
|
if err := os.WriteFile(filePath, []byte(content), 0600); err != nil {
|
|
t.Fatalf("Failed to write temp file: %v", err)
|
|
}
|
|
return filePath
|
|
}
|