From b18a1f561fa36d056aec1ff0e2a1dd604c74e45c Mon Sep 17 00:00:00 2001 From: Dimitar Pavlov Date: Sun, 6 Jul 2025 19:23:25 +0100 Subject: [PATCH] refactor test creation --- credentials/jwt/jwt_token_file.go | 4 +- credentials/jwt/jwt_token_file_test.go | 54 ++++++++------------------ 2 files changed, 18 insertions(+), 40 deletions(-) diff --git a/credentials/jwt/jwt_token_file.go b/credentials/jwt/jwt_token_file.go index 62b63e963..9d78e7beb 100644 --- a/credentials/jwt/jwt_token_file.go +++ b/credentials/jwt/jwt_token_file.go @@ -80,10 +80,10 @@ func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCreden // GetRequestMetadata gets the current request metadata, refreshing tokens // if required. This implementation follows the PerRPCCredentials interface. // The tokens will get automatically refreshed if they are about to expire or if -// they haven't been loaded successfully yet. In the latter case, a backoff is -// applied before retrying. +// they haven't been loaded successfully yet. // If it's not possible to extract a token from the file, UNAVAILABLE is returned. // If the token is extracted but invalid, then UNAUTHENTICATED is returned. +// If errors are encoutered, a backoff is applied before retrying. func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) { ri, _ := credentials.RequestInfoFromContext(ctx) if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil { diff --git a/credentials/jwt/jwt_token_file_test.go b/credentials/jwt/jwt_token_file_test.go index c611cc838..6405ef9c0 100644 --- a/credentials/jwt/jwt_token_file_test.go +++ b/credentials/jwt/jwt_token_file_test.go @@ -144,10 +144,7 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tokenFile := filepath.Join(tempDir, "token") - if err := os.WriteFile(tokenFile, []byte(tt.tokenContent), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", tt.tokenContent) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -189,18 +186,9 @@ func (s) TestTokenFileCallCreds_GetRequestMetadata(t *testing.T) { } func (s) TestTokenFileCallCreds_TokenCaching(t *testing.T) { - tempDir, err := os.MkdirTemp("", "jwt_test") - if err != nil { - t.Fatalf("Failed to create temp dir: %v", err) - } - defer os.RemoveAll(tempDir) - tokenFile := filepath.Join(tempDir, "token") token := createTestJWT(t, "", time.Now().Add(time.Hour)) - - if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -357,15 +345,10 @@ func createTestJWT(t *testing.T, audience string, expiration time.Time) string { // Tests that cached token expiration is set to 30 seconds before actual token expiration. func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Create token that expires in 2 hours tokenExp := time.Now().Truncate(time.Second).Add(2 * time.Hour) token := createTestJWT(t, "", tokenExp) - if err := os.WriteFile(tokenFile, []byte(token), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", token) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -398,16 +381,11 @@ func (s) TestTokenFileCallCreds_CacheExpirationIsBeforeTokenExpiration(t *testin // Tests that pre-emptive refresh is triggered within 1 minute of expiration. func (s) TestTokenFileCallCreds_PreemptiveRefreshIsTriggered(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Create token that expires in 80 seconds (=> cache expires in ~50s) // This ensures pre-emptive refresh triggers since 50s < the 1 minute check tokenExp := time.Now().Add(80 * time.Second) expiringToken := createTestJWT(t, "", tokenExp) - if err := os.WriteFile(tokenFile, []byte(expiringToken), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", expiringToken) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -658,14 +636,9 @@ func (s) TestTokenFileCallCreds_BackoffBehavior(t *testing.T) { // Tests that invalid JWT tokens are handled with UNAUTHENTICATED status. func (s) TestTokenFileCallCreds_InvalidJWTHandling(t *testing.T) { - tempDir := t.TempDir() - tokenFile := filepath.Join(tempDir, "token") - // Write invalid JWT (missing exp field) invalidJWT := createTestJWT(t, "", time.Time{}) // No expiration - if err := os.WriteFile(tokenFile, []byte(invalidJWT), 0600); err != nil { - t.Fatalf("Failed to write token file: %v", err) - } + tokenFile := writeTempFile(t, "token", invalidJWT) creds, err := NewTokenFileCallCredentials(tokenFile) if err != nil { @@ -752,13 +725,8 @@ func (s) TestTokenFileCallCreds_RPCQueueing(t *testing.T) { // Tests that no background retries occur when channel is idle. func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { - tempDir := t.TempDir() - tokenFilepath := filepath.Join(tempDir, "token") - newToken := createTestJWT(t, "", time.Now().Add(2*time.Hour)) - if err := os.WriteFile(tokenFilepath, []byte(newToken), 0600); err != nil { - t.Fatalf("Failed to write updated token file: %v", err) - } + tokenFilepath := writeTempFile(t, "token", newToken) creds, err := NewTokenFileCallCredentials(tokenFilepath) if err != nil { @@ -782,3 +750,13 @@ func (s) TestTokenFileCallCreds_NoIdleRetries(t *testing.T) { t.Errorf("after idle period, cached error = %v, want nil (no background reads)", cachedErr) } } + +func writeTempFile(t *testing.T, name, content string) string { + t.Helper() + tempDir := t.TempDir() + filePath := filepath.Join(tempDir, name) + if err := os.WriteFile(filePath, []byte(content), 0600); err != nil { + t.Fatalf("Failed to write temp file: %v", err) + } + return filePath +}