// Mirror of https://github.com/grpc/grpc-go.git (288 lines, 9.4 KiB, Go).
/*
|
|
*
|
|
* Copyright 2025 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
// Package jwt implements gRPC credentials using JWT tokens from files.
|
|
package jwt
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/internal/backoff"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// jwtClaims holds the subset of JWT claims this package needs: the "exp"
// (expiration time) claim, in seconds since the Unix epoch.
type jwtClaims struct {
	Exp int64 `json:"exp"`
}

// UnmarshalJSON decodes a JWT claims object, accepting "exp" as any JSON
// number. RFC 7519 defines NumericDate as a JSON numeric value that may carry a
// fractional part (or use exponent notation), which a plain int64 field would
// reject. Fractional seconds are truncated toward zero.
//
// A missing or null "exp" leaves Exp unchanged, matching the behavior of the
// default decoder. Non-numeric values and values that do not fit in an int64
// are rejected with an error.
func (c *jwtClaims) UnmarshalJSON(data []byte) error {
	var raw struct {
		Exp json.RawMessage `json:"exp"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}
	if len(raw.Exp) == 0 || string(raw.Exp) == "null" {
		return nil
	}

	// Integral values decode exactly, without a float64 round-trip.
	var exp int64
	if err := json.Unmarshal(raw.Exp, &exp); err == nil {
		c.Exp = exp
		return nil
	}

	var f float64
	if err := json.Unmarshal(raw.Exp, &f); err != nil {
		return fmt.Errorf("invalid exp claim: %v", err)
	}
	// Converting an out-of-range float to int64 is implementation-defined in
	// Go, so reject such values explicitly.
	if f >= 1<<63 || f < -(1<<63) {
		return fmt.Errorf("exp claim %s is out of range", raw.Exp)
	}
	c.Exp = int64(f)
	return nil
}
|
|
|
|
// jwtTokenFileCallCreds provides JWT token-based PerRPCCredentials that reads
|
|
// tokens from a file.
|
|
// This implementation follows the A97 JWT Call Credentials specification.
|
|
type jwtTokenFileCallCreds struct {
|
|
tokenFilePath string
|
|
|
|
// Cached token data
|
|
mu sync.RWMutex
|
|
cachedToken string
|
|
cachedExpiration time.Time // Slightly reduced expiration time compared to the actual exp
|
|
|
|
// Error caching with backoff
|
|
cachedError error // Cached error from last failed attempt
|
|
cachedErrorTime time.Time // When the error was cached
|
|
backoffStrategy backoff.Strategy // Backoff strategy when error occurs
|
|
retryAttempt int // Current retry attempt number
|
|
nextRetryTime time.Time // When next retry is allowed
|
|
|
|
// Pre-emptive refresh mutex
|
|
refreshMu sync.Mutex
|
|
}
|
|
|
|
// NewTokenFileCallCredentials creates PerRPCCredentials that reads JWT tokens
|
|
// from the specified file path.
|
|
//
|
|
// tokenFilePath is the filepath to the JWT token file.
|
|
func NewTokenFileCallCredentials(tokenFilePath string) (credentials.PerRPCCredentials, error) {
|
|
if tokenFilePath == "" {
|
|
return nil, fmt.Errorf("tokenFilePath cannot be empty")
|
|
}
|
|
|
|
return &jwtTokenFileCallCreds{
|
|
tokenFilePath: tokenFilePath,
|
|
backoffStrategy: backoff.DefaultExponential,
|
|
}, nil
|
|
}
|
|
|
|
// GetRequestMetadata gets the current request metadata, refreshing tokens
|
|
// if required. This implementation follows the PerRPCCredentials interface.
|
|
// The tokens will get automatically refreshed if they are about to expire or if
|
|
// they haven't been loaded successfully yet.
|
|
// If it's not possible to extract a token from the file, UNAVAILABLE is returned.
|
|
// If the token is extracted but invalid, then UNAUTHENTICATED is returned.
|
|
// If errors are encoutered, a backoff is applied before retrying.
|
|
func (c *jwtTokenFileCallCreds) GetRequestMetadata(ctx context.Context, _ ...string) (map[string]string, error) {
|
|
ri, _ := credentials.RequestInfoFromContext(ctx)
|
|
if err := credentials.CheckSecurityLevel(ri.AuthInfo, credentials.PrivacyAndIntegrity); err != nil {
|
|
return nil, fmt.Errorf("unable to transfer JWT token file PerRPCCredentials: %v", err)
|
|
}
|
|
|
|
// this may be delayed if the token needs to be refreshed from file
|
|
token, err := c.getToken(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return map[string]string{
|
|
"authorization": "Bearer " + token,
|
|
}, nil
|
|
}
|
|
|
|
// RequireTransportSecurity indicates whether the credentials requires
|
|
// transport security.
|
|
func (c *jwtTokenFileCallCreds) RequireTransportSecurity() bool {
|
|
return true
|
|
}
|
|
|
|
// getToken returns a valid JWT token, reading from file if necessary.
|
|
// Implements pre-emptive refresh and caches errors with backoff.
|
|
func (c *jwtTokenFileCallCreds) getToken(ctx context.Context) (string, error) {
|
|
c.mu.RLock()
|
|
|
|
if c.isTokenValid() {
|
|
token := c.cachedToken
|
|
shouldRefresh := c.needsPreemptiveRefresh()
|
|
c.mu.RUnlock()
|
|
|
|
if shouldRefresh {
|
|
c.triggerPreemptiveRefresh()
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
// if still within backoff period, return cached error to avoid repeated file reads
|
|
if c.cachedError != nil && time.Now().Before(c.nextRetryTime) {
|
|
err := c.cachedError
|
|
c.mu.RUnlock()
|
|
return "", err
|
|
}
|
|
|
|
c.mu.RUnlock()
|
|
// Token is expired or missing or the retry backoff period has expired. So
|
|
// refresh synchronously.
|
|
// NOTE: refreshTokenSync itself acquires the write lock
|
|
return c.refreshTokenSync(ctx, false)
|
|
}
|
|
|
|
// isTokenValid checks if the cached token is still valid.
|
|
// Caller must hold c.mu.RLock().
|
|
func (c *jwtTokenFileCallCreds) isTokenValid() bool {
|
|
if c.cachedToken == "" {
|
|
return false
|
|
}
|
|
return c.cachedExpiration.After(time.Now())
|
|
}
|
|
|
|
// needsPreemptiveRefresh checks if a pre-emptive refresh should be triggered.
|
|
// Returns true if the cached token is valid but expires within 1 minute.
|
|
// We only trigger pre-emptive refresh for valid tokens - if the token is invalid
|
|
// or expired, the next RPC will handle synchronous refresh instead.
|
|
// Caller must hold c.mu.RLock().
|
|
func (c *jwtTokenFileCallCreds) needsPreemptiveRefresh() bool {
|
|
return c.isTokenValid() && time.Until(c.cachedExpiration) < time.Minute
|
|
}
|
|
|
|
// triggerPreemptiveRefresh starts a background refresh if needed.
|
|
// Multiple concurrent calls are safe - only one refresh will run at a time.
|
|
// The refresh runs in a separate goroutine and does not block the caller.
|
|
func (c *jwtTokenFileCallCreds) triggerPreemptiveRefresh() {
|
|
go func() {
|
|
c.refreshMu.Lock()
|
|
defer c.refreshMu.Unlock()
|
|
|
|
// Re-check if refresh is still needed under mutex
|
|
c.mu.RLock()
|
|
stillNeeded := c.needsPreemptiveRefresh()
|
|
c.mu.RUnlock()
|
|
|
|
if !stillNeeded {
|
|
return // Another goroutine already refreshed or token expired
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
// Force refresh to read new token even if current one is still valid
|
|
_, _ = c.refreshTokenSync(ctx, true)
|
|
}()
|
|
}
|
|
|
|
// refreshTokenSync reads a new token from the file and updates the cache. If
|
|
// preemptiveRefresh is true, bypasses the validity check of the currently cached
|
|
// token and always reads from file.
|
|
// This is used for pre-emptive refresh to ensure new tokens are loaded even when
|
|
// the cached token is still valid. If preemptiveRefresh is false, skips file read
|
|
// when cached token is still valid, optimizing concurrent synchronous refresh calls
|
|
// where one RPC may have already updated the cache while another was waiting on the lock.
|
|
func (c *jwtTokenFileCallCreds) refreshTokenSync(_ context.Context, preemptiveRefresh bool) (string, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
// Double-check under write lock but skip if preemptive refresh is requested
|
|
if !preemptiveRefresh && c.isTokenValid() {
|
|
return c.cachedToken, nil
|
|
}
|
|
|
|
tokenBytes, err := os.ReadFile(c.tokenFilePath)
|
|
if err != nil {
|
|
err = status.Errorf(codes.Unavailable, "failed to read token file %q: %v", c.tokenFilePath, err)
|
|
c.setErrorWithBackoff(err)
|
|
return "", err
|
|
}
|
|
|
|
token := strings.TrimSpace(string(tokenBytes))
|
|
if token == "" {
|
|
err := status.Errorf(codes.Unavailable, "token file %q is empty", c.tokenFilePath)
|
|
c.setErrorWithBackoff(err)
|
|
return "", err
|
|
}
|
|
|
|
// Parse JWT to extract expiration
|
|
exp, err := c.extractExpiration(token)
|
|
if err != nil {
|
|
err = status.Errorf(codes.Unauthenticated, "failed to parse JWT from token file %q: %v", c.tokenFilePath, err)
|
|
c.setErrorWithBackoff(err)
|
|
return "", err
|
|
}
|
|
|
|
// Success - clear any cached error and backoff state, update token cache
|
|
c.clearErrorAndBackoff()
|
|
c.cachedToken = token
|
|
// Per RFC A97: consider token invalid if it expires within the next 30
|
|
// seconds to accommodate for clock skew and server processing time.
|
|
c.cachedExpiration = exp.Add(-30 * time.Second)
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// extractExpiration parses the JWT token to extract the expiration time.
|
|
func (c *jwtTokenFileCallCreds) extractExpiration(token string) (time.Time, error) {
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 3 {
|
|
return time.Time{}, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
|
}
|
|
|
|
// Decode the payload (second part)
|
|
payload := parts[1]
|
|
|
|
// Add padding if necessary for base64 decoding
|
|
for len(payload)%4 != 0 {
|
|
payload += "="
|
|
}
|
|
|
|
payloadBytes, err := base64.URLEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return time.Time{}, fmt.Errorf("failed to decode JWT payload: %v", err)
|
|
}
|
|
|
|
var claims jwtClaims
|
|
if err := json.Unmarshal(payloadBytes, &claims); err != nil {
|
|
return time.Time{}, fmt.Errorf("failed to unmarshal JWT claims: %v", err)
|
|
}
|
|
|
|
if claims.Exp == 0 {
|
|
return time.Time{}, fmt.Errorf("JWT token has no expiration claim")
|
|
}
|
|
|
|
expTime := time.Unix(claims.Exp, 0)
|
|
|
|
// Check if token is already expired
|
|
if expTime.Before(time.Now()) {
|
|
return time.Time{}, fmt.Errorf("JWT token is expired")
|
|
}
|
|
|
|
return expTime, nil
|
|
}
|
|
|
|
// setErrorWithBackoff caches an error and calculates the next retry time using exponential backoff.
|
|
// Caller must hold c.mu write lock.
|
|
func (c *jwtTokenFileCallCreds) setErrorWithBackoff(err error) {
|
|
c.cachedError = err
|
|
c.cachedErrorTime = time.Now()
|
|
c.retryAttempt++
|
|
backoffDelay := c.backoffStrategy.Backoff(c.retryAttempt - 1)
|
|
c.nextRetryTime = time.Now().Add(backoffDelay)
|
|
}
|
|
|
|
// clearErrorAndBackoff clears the cached error and resets backoff state.
|
|
// Caller must hold c.mu write lock.
|
|
func (c *jwtTokenFileCallCreds) clearErrorAndBackoff() {
|
|
c.cachedError = nil
|
|
c.cachedErrorTime = time.Time{}
|
|
c.retryAttempt = 0
|
|
c.nextRetryTime = time.Time{}
|
|
}
|