Merge pull request #2543 from danielaskdd/token-expire-renew

feat: Implement WebUI Token Auto-Renewal (Sliding Window Expiration)
This commit is contained in:
Daniel.y
2025-12-26 14:09:46 +08:00
committed by GitHub
7 changed files with 799 additions and 14 deletions

View File

@@ -49,6 +49,23 @@ OLLAMA_EMULATING_MODEL_TAG=latest
# GUEST_TOKEN_EXPIRE_HOURS=24
# JWT_ALGORITHM=HS256
### Token Auto-Renewal Configuration (Sliding Window Expiration)
### Enable automatic token renewal to prevent active users from being logged out
### When enabled, tokens will be automatically renewed when remaining time < threshold
# TOKEN_AUTO_RENEW=true
### Token renewal threshold (0.0 - 1.0)
### Renew token when remaining time < (total time * threshold)
### Default: 0.5 (renew when 50% time remaining)
### Examples:
### 0.5 = renew when 24h token has 12h left
### 0.25 = renew when 24h token has 6h left
# TOKEN_RENEW_THRESHOLD=0.5
### Note: Token renewal is automatically skipped for certain endpoints:
### - /health: Health check endpoint (no authentication required)
### - /documents/paginated: Frequently polled by client (5-30s interval)
### - /documents/pipeline_status: Very frequently polled by client (2s interval)
### - Rate limit: Minimum 60 seconds between renewals for same user
### API-Key to access LightRAG Server API
### Use this key in HTTP requests with the 'X-API-Key' header
### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query

View File

@@ -403,10 +403,14 @@ def parse_args() -> argparse.Namespace:
# For JWT Auth
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, float)
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, float)
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
# Token auto-renewal configuration (sliding window expiration)
args.token_auto_renew = get_env_value("TOKEN_AUTO_RENEW", True, bool)
args.token_renew_threshold = get_env_value("TOKEN_RENEW_THRESHOLD", 0.5, float)
# Rerank model configuration
args.rerank_model = get_env_value("RERANK_MODEL", None)
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)

View File

@@ -451,6 +451,9 @@ def create_app(args):
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=[
"X-New-Token"
], # Expose token renewal header for cross-origin requests
)
# Create combined auth dependency for all endpoints

View File

@@ -6,18 +6,39 @@ import os
import argparse
from typing import Optional, List, Tuple
import sys
import time
import logging
from ascii_colors import ASCIIColors
from lightrag.api import __api_version__ as api_version
from lightrag import __version__ as core_version
from lightrag.constants import (
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
)
from fastapi import HTTPException, Security, Request, status
from fastapi import HTTPException, Security, Request, Response, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from starlette.status import HTTP_403_FORBIDDEN
from .auth import auth_handler
from .config import ollama_server_infos, global_args, get_env_value
logger = logging.getLogger("lightrag")
# ========== Token Renewal Rate Limiting ==========
# Cache to track last renewal time per user (username as key)
# Format: {username: last_renewal_timestamp}
_token_renewal_cache: dict[str, float] = {}
_RENEWAL_MIN_INTERVAL = 60 # Minimum 60 seconds between renewals for same user
# ========== Token Renewal Path Exclusions ==========
# Paths that should NOT trigger token auto-renewal
# - /health: Health check endpoint, no login required
# - /documents/paginated: Client polls this frequently (5-30s), renewal not needed
# - /documents/pipeline_status: Client polls this very frequently (2s), renewal not needed
_TOKEN_RENEWAL_SKIP_PATHS = [
"/health",
"/documents/paginated",
"/documents/pipeline_status",
]
def check_env_file():
"""
@@ -87,6 +108,7 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
async def combined_dependency(
request: Request,
response: Response, # Added: needed to return new token via response header
token: str = Security(oauth2_scheme),
api_key_header_value: Optional[str] = None
if api_key_header is None
@@ -104,6 +126,80 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
if token:
try:
token_info = auth_handler.validate_token(token)
# ========== Token Auto-Renewal Logic ==========
from lightrag.api.config import global_args
from datetime import datetime
if global_args.token_auto_renew:
# Check if current path should skip token renewal
skip_renewal = any(
path == skip_path or path.startswith(skip_path + "/")
for skip_path in _TOKEN_RENEWAL_SKIP_PATHS
)
if skip_renewal:
logger.debug(f"Token auto-renewal skipped for path: {path}")
else:
try:
expire_time = token_info.get("exp")
if expire_time:
# Calculate remaining time ratio
now = datetime.utcnow()
remaining_seconds = (expire_time - now).total_seconds()
# Get original token expiration duration
role = token_info.get("role", "user")
total_hours = (
auth_handler.guest_expire_hours
if role == "guest"
else auth_handler.expire_hours
)
total_seconds = total_hours * 3600
# Issue new token if remaining time < threshold
if (
remaining_seconds
< total_seconds * global_args.token_renew_threshold
):
# ========== Rate Limiting Check ==========
username = token_info["username"]
current_time = time.time()
last_renewal = _token_renewal_cache.get(username, 0)
time_since_last_renewal = (
current_time - last_renewal
)
# Only renew if enough time has passed since last renewal
if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
new_token = auth_handler.create_token(
username=username,
role=role,
metadata=token_info.get("metadata", {}),
)
# Return new token via response header
response.headers["X-New-Token"] = new_token
# Update renewal cache
_token_renewal_cache[username] = current_time
# Optional: log renewal
logger.info(
f"Token auto-renewed for user {username} "
f"(role: {role}, remaining: {remaining_seconds:.0f}s)"
)
else:
# Log skip due to rate limit
logger.debug(
f"Token renewal skipped for {username} "
f"(rate limit: last renewal {time_since_last_renewal:.0f}s ago)"
)
# ========== End of Rate Limiting Check ==========
except Exception as e:
# Renewal failure should not affect normal request, just log
logger.warning(f"Token auto-renew failed: {e}")
# ========== End of Token Auto-Renewal Logic ==========
# Accept guest token if no auth is configured
if not auth_configured and token_info.get("role") == "guest":
return

View File

@@ -2,6 +2,7 @@ import axios, { AxiosError } from 'axios'
import { backendBaseUrl, popularLabelsDefaultLimit, searchLabelsDefaultLimit } from '@/lib/constants'
import { errorMessage } from '@/lib/utils'
import { useSettingsStore } from '@/stores/settings'
import { useAuthStore } from '@/stores/state'
import { navigationService } from '@/services/navigation'
// Types
@@ -285,8 +286,62 @@ const axiosInstance = axios.create({
}
})
// ========== Token Management ==========
// Prevent multiple requests from triggering token refresh simultaneously
let isRefreshingGuestToken = false;
let refreshTokenPromise: Promise<string> | null = null;
// Silent refresh for guest token
const silentRefreshGuestToken = async (): Promise<string> => {
// If already refreshing, return the same Promise
if (isRefreshingGuestToken && refreshTokenPromise) {
return refreshTokenPromise;
}
isRefreshingGuestToken = true;
refreshTokenPromise = (async () => {
try {
// Call /auth-status to get new guest token
const response = await axios.get('/auth-status', {
baseURL: backendBaseUrl,
// This request must skip the interceptor to avoid adding expired token
headers: { 'X-Skip-Interceptor': 'true' }
});
if (response.data.access_token && !response.data.auth_configured) {
const newToken = response.data.access_token;
// Update localStorage
localStorage.setItem('LIGHTRAG-API-TOKEN', newToken);
// Update auth state
useAuthStore.getState().login(
newToken,
true,
response.data.core_version,
response.data.api_version,
response.data.webui_title || null,
response.data.webui_description || null
);
return newToken;
} else {
throw new Error('Failed to get guest token');
}
} finally {
isRefreshingGuestToken = false;
refreshTokenPromise = null;
}
})();
return refreshTokenPromise;
};
// Interceptor: add api key and check authentication
axiosInstance.interceptors.request.use((config) => {
// Skip interceptor for token refresh requests
if (config.headers['X-Skip-Interceptor']) {
delete config.headers['X-Skip-Interceptor'];
return config;
}
const apiKey = useSettingsStore.getState().apiKey
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
@@ -300,20 +355,88 @@ axiosInstance.interceptors.request.use((config) => {
return config
})
// Interceptorhanle error
// Interceptorhandle token renewal and authentication errors
axiosInstance.interceptors.response.use(
(response) => response,
(error: AxiosError) => {
(response) => {
// ========== Check for new token from backend ==========
const newToken = response.headers['x-new-token'];
if (newToken) {
localStorage.setItem('LIGHTRAG-API-TOKEN', newToken);
// Optional: log in development mode
if (import.meta.env.DEV) {
console.log('[Auth] Token auto-renewed by backend');
}
// Update auth state with renewal tracking
try {
const payload = JSON.parse(atob(newToken.split('.')[1]));
const authStore = useAuthStore.getState();
if (authStore.isAuthenticated) {
// Track token renewal time and expiration
const renewalTime = Date.now();
const expiresAt = payload.exp ? payload.exp * 1000 : 0;
authStore.setTokenRenewal(renewalTime, expiresAt);
// Update username (usually unchanged, but just in case)
const newUsername = payload.sub;
if (newUsername && newUsername !== authStore.username) {
// Need to add setUsername method or just update via login
// For now, we'll skip username update as it's rare
}
}
} catch (error) {
console.warn('[Auth] Failed to parse renewed token:', error);
}
}
// ========== End of token renewal check ==========
return response;
},
async (error: AxiosError) => {
if (error.response) {
if (error.response?.status === 401) {
// For login API, throw error directly
if (error.config?.url?.includes('/login')) {
const originalRequest = error.config;
// 1. For login API, throw error directly
if (originalRequest?.url?.includes('/login')) {
throw error;
}
// For other APIs, navigate to login page
navigationService.navigateToLogin();
// return a reject Promise
// 2. Prevent infinite retry
if (originalRequest && (originalRequest as any)._retry) {
navigationService.navigateToLogin();
return Promise.reject(new Error('Authentication required'));
}
// 3. Check if in guest mode
const authStore = useAuthStore.getState();
const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN');
const isGuest = currentToken && authStore.isGuestMode;
// 4. Guest mode: silent refresh and retry
if (isGuest && originalRequest) {
try {
const newToken = await silentRefreshGuestToken();
// Mark as retried to prevent infinite loop
(originalRequest as any)._retry = true;
// Update token in request headers
originalRequest.headers['Authorization'] = `Bearer ${newToken}`;
// Retry original request
return axiosInstance(originalRequest);
} catch (refreshError) {
console.error('Failed to refresh guest token:', refreshError);
// Refresh failed, navigate to login
navigationService.navigateToLogin();
return Promise.reject(new Error('Failed to refresh authentication'));
}
}
// 5. Non-guest mode: navigate to login page
navigationService.navigateToLogin();
return Promise.reject(new Error('Authentication required'));
}
throw new Error(
@@ -418,7 +541,88 @@ export const queryTextStream = async (
if (!response.ok) {
// Handle 401 Unauthorized error specifically
if (response.status === 401) {
// For consistency with axios interceptor, navigate to login page
// Check if in guest mode
const authStore = useAuthStore.getState();
const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN');
const isGuest = currentToken && authStore.isGuestMode;
if (isGuest) {
try {
// Silent refresh token for guest mode
const newToken = await silentRefreshGuestToken();
// Retry stream request with new token
const retryHeaders = { ...headers };
retryHeaders['Authorization'] = `Bearer ${newToken}`;
const retryResponse = await fetch(`${backendBaseUrl}/query/stream`, {
method: 'POST',
headers: retryHeaders,
body: JSON.stringify(request),
});
if (!retryResponse.ok) {
throw new Error(`HTTP error! status: ${retryResponse.status}`);
}
// Retry successful, process stream response
// Re-execute the stream processing logic with retryResponse
if (!retryResponse.body) {
throw new Error('Response body is null');
}
const reader = retryResponse.body.getReader();
const decoder = new TextDecoder();
let buffer = '';
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const lines = buffer.split('\n');
buffer = lines.pop() || '';
for (const line of lines) {
if (line.trim()) {
try {
const parsed = JSON.parse(line);
if (parsed.response) {
onChunk(parsed.response);
} else if (parsed.error) {
onError?.(parsed.error);
}
} catch (parseError) {
console.error('Failed to parse JSON:', parseError, 'Line:', line);
onError?.(`JSON parse error: ${parseError}`);
}
}
}
}
// Process any remaining data in buffer
if (buffer.trim()) {
try {
const parsed = JSON.parse(buffer);
if (parsed.response) {
onChunk(parsed.response);
} else if (parsed.error) {
onError?.(parsed.error);
}
} catch (parseError) {
console.error('Failed to parse final buffer:', parseError);
}
}
return; // Successfully completed retry
} catch (refreshError) {
console.error('Failed to refresh guest token for streaming:', refreshError);
navigationService.navigateToLogin();
throw new Error('Failed to refresh authentication');
}
}
// Non-guest mode: navigate to login page
navigationService.navigateToLogin();
// Create a specific authentication error

View File

@@ -33,11 +33,14 @@ interface AuthState {
username: string | null; // login username
webuiTitle: string | null; // Custom title
webuiDescription: string | null; // Title description
lastTokenRenewal: string | null; // Human-readable local time of last token renewal (for debugging and monitoring)
tokenExpiresAt: number | null; // Token expiration timestamp (extracted from JWT)
login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void;
logout: () => void;
setVersion: (coreVersion: string | null, apiVersion: string | null) => void;
setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void;
setTokenRenewal: (renewalTime: number, expiresAt: number) => void; // Track token renewal
}
const useBackendStateStoreBase = create<BackendState>()((set, get) => ({
@@ -156,7 +159,19 @@ const useBackendState = createSelectors(useBackendStateStoreBase)
export { useBackendState }
const parseTokenPayload = (token: string): { sub?: string; role?: string } => {
// Format timestamp to human-readable local time with timezone
const formatTimestampToLocalString = (timestamp: number): string => {
const date = new Date(timestamp);
// Use Swedish locale 'sv-SE' to get YYYY-MM-DD HH:mm:ss format
const localTime = date.toLocaleString('sv-SE', { hour12: false });
// Get timezone offset
const offsetMinutes = -date.getTimezoneOffset();
const offsetHours = Math.floor(Math.abs(offsetMinutes) / 60);
const offsetSign = offsetMinutes >= 0 ? '+' : '-';
return `${localTime} (UTC${offsetSign}${offsetHours})`;
};
const parseTokenPayload = (token: string): { sub?: string; role?: string; exp?: number } => {
try {
// JWT tokens are in the format: header.payload.signature
const parts = token.split('.');
@@ -179,13 +194,20 @@ const isGuestToken = (token: string): boolean => {
return payload.role === 'guest';
};
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => {
const getTokenExpiresAt = (token: string): number | null => {
const payload = parseTokenPayload(token);
return payload.exp ? payload.exp * 1000 : null; // Convert to milliseconds
};
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null; lastTokenRenewal: string | null; tokenExpiresAt: number | null } => {
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
const lastTokenRenewal = localStorage.getItem('LIGHTRAG-LAST-TOKEN-RENEWAL');
const username = token ? getUsernameFromToken(token) : null;
const tokenExpiresAt = token ? getTokenExpiresAt(token) : null;
if (!token) {
return {
@@ -196,6 +218,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
username: null,
webuiTitle: webuiTitle,
webuiDescription: webuiDescription,
lastTokenRenewal: null,
tokenExpiresAt: null,
};
}
@@ -207,6 +231,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
username: username,
webuiTitle: webuiTitle,
webuiDescription: webuiDescription,
lastTokenRenewal: lastTokenRenewal,
tokenExpiresAt: tokenExpiresAt,
};
};
@@ -222,6 +248,8 @@ export const useAuthStore = create<AuthState>(set => {
username: initialState.username,
webuiTitle: initialState.webuiTitle,
webuiDescription: initialState.webuiDescription,
lastTokenRenewal: initialState.lastTokenRenewal,
tokenExpiresAt: initialState.tokenExpiresAt,
login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => {
localStorage.setItem('LIGHTRAG-API-TOKEN', token);
@@ -246,6 +274,13 @@ export const useAuthStore = create<AuthState>(set => {
}
const username = getUsernameFromToken(token);
const tokenExpiresAt = getTokenExpiresAt(token);
const now = Date.now();
const formattedTime = formatTimestampToLocalString(now);
// Initialize token issuance time with human-readable format
localStorage.setItem('LIGHTRAG-LAST-TOKEN-RENEWAL', formattedTime);
set({
isAuthenticated: true,
isGuestMode: isGuest,
@@ -254,11 +289,14 @@ export const useAuthStore = create<AuthState>(set => {
apiVersion: apiVersion,
webuiTitle: webuiTitle,
webuiDescription: webuiDescription,
tokenExpiresAt: tokenExpiresAt,
lastTokenRenewal: formattedTime,
});
},
logout: () => {
localStorage.removeItem('LIGHTRAG-API-TOKEN');
localStorage.removeItem('LIGHTRAG-LAST-TOKEN-RENEWAL');
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
@@ -273,6 +311,8 @@ export const useAuthStore = create<AuthState>(set => {
apiVersion: apiVersion,
webuiTitle: webuiTitle,
webuiDescription: webuiDescription,
lastTokenRenewal: null,
tokenExpiresAt: null,
});
},
@@ -311,6 +351,19 @@ export const useAuthStore = create<AuthState>(set => {
webuiTitle: webuiTitle,
webuiDescription: webuiDescription
});
},
setTokenRenewal: (renewalTime, expiresAt) => {
const formattedTime = formatTimestampToLocalString(renewalTime);
// Update localStorage with human-readable format
localStorage.setItem('LIGHTRAG-LAST-TOKEN-RENEWAL', formattedTime);
// Update state
set({
lastTokenRenewal: formattedTime,
tokenExpiresAt: expiresAt
});
}
};
});

View File

@@ -0,0 +1,408 @@
"""
Pytest unit tests for token auto-renewal functionality
Tests:
1. Backend token renewal logic
2. Rate limiting for token renewals
3. Token renewal state tracking
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock
from fastapi import Response
import time
import sys
# Mock the config before importing utils_api
sys.modules["lightrag.api.config"] = Mock()
sys.modules["lightrag.api.auth"] = Mock()
# Create a simple token renewal cache for testing
_token_renewal_cache = {}
_RENEWAL_MIN_INTERVAL = 60
@pytest.mark.offline
class TestTokenRenewal:
"""Tests for token auto-renewal logic"""
@pytest.fixture
def mock_auth_handler(self):
"""Mock authentication handler"""
handler = Mock()
handler.guest_expire_hours = 24
handler.expire_hours = 24
handler.create_token = Mock(return_value="new-token-12345")
return handler
@pytest.fixture
def mock_global_args(self):
"""Mock global configuration"""
args = Mock()
args.token_auto_renew = True
args.token_renew_threshold = 0.5
return args
@pytest.fixture
def mock_token_info_guest(self):
"""Mock token info for guest user"""
# Token with 10 hours remaining (below 50% of 24 hours)
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
return {
"username": "guest",
"role": "guest",
"exp": exp_time,
"metadata": {"auth_mode": "disabled"},
}
@pytest.fixture
def mock_token_info_user(self):
"""Mock token info for regular user"""
# Token with 10 hours remaining (below 50% of 24 hours)
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
return {
"username": "testuser",
"role": "user",
"exp": exp_time,
"metadata": {"auth_mode": "enabled"},
}
@pytest.fixture
def mock_token_info_above_threshold(self):
"""Mock token info with time above renewal threshold"""
# Token with 20 hours remaining (above 50% of 24 hours)
exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
return {
"username": "testuser",
"role": "user",
"exp": exp_time,
"metadata": {"auth_mode": "enabled"},
}
def test_token_renewal_when_below_threshold(
self, mock_auth_handler, mock_global_args, mock_token_info_user
):
"""Test that token is renewed when remaining time < threshold"""
# Use global cache
global _token_renewal_cache
# Clear cache
_token_renewal_cache.clear()
response = Mock(spec=Response)
response.headers = {}
# Simulate the renewal logic
expire_time = mock_token_info_user["exp"]
now = datetime.now(timezone.utc)
remaining_seconds = (expire_time - now).total_seconds()
role = mock_token_info_user["role"]
total_hours = (
mock_auth_handler.expire_hours
if role == "user"
else mock_auth_handler.guest_expire_hours
)
total_seconds = total_hours * 3600
# Should renew because remaining_seconds < total_seconds * 0.5
should_renew = (
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
)
assert should_renew is True
# Simulate renewal
username = mock_token_info_user["username"]
current_time = time.time()
last_renewal = _token_renewal_cache.get(username, 0)
time_since_last_renewal = current_time - last_renewal
# Should pass rate limit (first renewal)
assert time_since_last_renewal >= 60 or last_renewal == 0
# Perform renewal
new_token = mock_auth_handler.create_token(
username=username, role=role, metadata=mock_token_info_user["metadata"]
)
response.headers["X-New-Token"] = new_token
_token_renewal_cache[username] = current_time
# Verify
assert "X-New-Token" in response.headers
assert response.headers["X-New-Token"] == "new-token-12345"
assert username in _token_renewal_cache
def test_token_no_renewal_when_above_threshold(
self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
):
"""Test that token is NOT renewed when remaining time > threshold"""
response = Mock(spec=Response)
response.headers = {}
expire_time = mock_token_info_above_threshold["exp"]
now = datetime.now(timezone.utc)
remaining_seconds = (expire_time - now).total_seconds()
mock_token_info_above_threshold["role"]
total_hours = mock_auth_handler.expire_hours
total_seconds = total_hours * 3600
# Should NOT renew because remaining_seconds > total_seconds * 0.5
should_renew = (
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
)
assert should_renew is False
# No renewal should happen
assert "X-New-Token" not in response.headers
def test_token_renewal_disabled(
self, mock_auth_handler, mock_global_args, mock_token_info_user
):
"""Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
mock_global_args.token_auto_renew = False
response = Mock(spec=Response)
response.headers = {}
# Auto-renewal is disabled, so even if below threshold, no renewal
if not mock_global_args.token_auto_renew:
# Skip renewal logic
pass
assert "X-New-Token" not in response.headers
def test_token_renewal_for_guest_mode(
self, mock_auth_handler, mock_global_args, mock_token_info_guest
):
"""Test that guest tokens are renewed correctly"""
# Use global cache
global _token_renewal_cache
_token_renewal_cache.clear()
response = Mock(spec=Response)
response.headers = {}
expire_time = mock_token_info_guest["exp"]
now = datetime.now(timezone.utc)
remaining_seconds = (expire_time - now).total_seconds()
role = mock_token_info_guest["role"]
total_hours = mock_auth_handler.guest_expire_hours
total_seconds = total_hours * 3600
should_renew = (
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
)
assert should_renew is True
# Renewal for guest
username = mock_token_info_guest["username"]
new_token = mock_auth_handler.create_token(
username=username, role=role, metadata=mock_token_info_guest["metadata"]
)
response.headers["X-New-Token"] = new_token
_token_renewal_cache[username] = time.time()
assert "X-New-Token" in response.headers
assert username in _token_renewal_cache
@pytest.mark.offline
class TestRateLimiting:
"""Tests for token renewal rate limiting"""
@pytest.fixture
def mock_auth_handler(self):
"""Mock authentication handler"""
handler = Mock()
handler.expire_hours = 24
handler.create_token = Mock(return_value="new-token-12345")
return handler
def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
"""Test that second renewal within 60s is blocked"""
# Use global cache and constant
global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
username = "testuser"
_token_renewal_cache.clear()
# First renewal
current_time_1 = time.time()
_token_renewal_cache[username] = current_time_1
response_1 = Mock(spec=Response)
response_1.headers = {}
response_1.headers["X-New-Token"] = "new-token-12345"
# Immediate second renewal attempt (within 60s)
current_time_2 = time.time() # Almost same time
last_renewal = _token_renewal_cache.get(username, 0)
time_since_last_renewal = current_time_2 - last_renewal
# Should be blocked by rate limit
assert time_since_last_renewal < _RENEWAL_MIN_INTERVAL
response_2 = Mock(spec=Response)
response_2.headers = {}
# No new token should be issued
if time_since_last_renewal < _RENEWAL_MIN_INTERVAL:
# Rate limited, skip renewal
pass
assert "X-New-Token" not in response_2.headers
def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
"""Test that renewal succeeds after 60s interval"""
# Use global cache and constant
global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
username = "testuser"
_token_renewal_cache.clear()
# First renewal at time T
first_renewal_time = time.time() - 61 # 61 seconds ago
_token_renewal_cache[username] = first_renewal_time
# Second renewal attempt now
current_time = time.time()
last_renewal = _token_renewal_cache.get(username, 0)
time_since_last_renewal = current_time - last_renewal
# Should pass rate limit (>60s elapsed)
assert time_since_last_renewal >= _RENEWAL_MIN_INTERVAL
response = Mock(spec=Response)
response.headers = {}
if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
new_token = mock_auth_handler.create_token(
username=username, role="user", metadata={}
)
response.headers["X-New-Token"] = new_token
_token_renewal_cache[username] = current_time
assert "X-New-Token" in response.headers
assert response.headers["X-New-Token"] == "new-token-12345"
def test_rate_limit_per_user(self, mock_auth_handler):
"""Test that different users have independent rate limits"""
# Use global cache
global _token_renewal_cache
_token_renewal_cache.clear()
user1 = "user1"
user2 = "user2"
current_time = time.time()
# User1 gets renewal
_token_renewal_cache[user1] = current_time
# User2 should still be able to get renewal (independent cache)
last_renewal_user2 = _token_renewal_cache.get(user2, 0)
assert last_renewal_user2 == 0 # No previous renewal
# User2 can renew
_token_renewal_cache[user2] = current_time
# Both users should have entries
assert user1 in _token_renewal_cache
assert user2 in _token_renewal_cache
assert _token_renewal_cache[user1] == _token_renewal_cache[user2]
@pytest.mark.offline
class TestTokenExpirationCalculation:
"""Tests for token expiration time calculation"""
def test_expiration_extraction_from_jwt(self):
"""Test extracting expiration time from JWT token"""
import base64
import json
# Create a mock JWT payload
exp_timestamp = int(
(datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
)
payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}
# Encode as base64 (simulating JWT structure: header.payload.signature)
payload_b64 = base64.b64encode(json.dumps(payload).encode()).decode()
mock_token = f"header.{payload_b64}.signature"
# Simulate extraction
parts = mock_token.split(".")
assert len(parts) == 3
decoded_payload = json.loads(base64.b64decode(parts[1]))
assert decoded_payload["exp"] == exp_timestamp
assert decoded_payload["sub"] == "testuser"
def test_remaining_time_calculation(self):
"""Test calculation of remaining token time"""
# Token expires in 10 hours
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
now = datetime.now(timezone.utc)
remaining_seconds = (exp_time - now).total_seconds()
# Should be approximately 10 hours (36000 seconds)
assert 35990 < remaining_seconds < 36010
# Calculate percentage remaining (for 24-hour token)
total_seconds = 24 * 3600
percentage_remaining = remaining_seconds / total_seconds
# Should be approximately 41.67% remaining
assert 0.41 < percentage_remaining < 0.42
def test_threshold_comparison(self):
"""Test threshold-based renewal decision"""
threshold = 0.5
total_hours = 24
total_seconds = total_hours * 3600
# Scenario 1: 10 hours remaining -> should renew
remaining_seconds_1 = 10 * 3600
should_renew_1 = remaining_seconds_1 < total_seconds * threshold
assert should_renew_1 is True
# Scenario 2: 20 hours remaining -> should NOT renew
remaining_seconds_2 = 20 * 3600
should_renew_2 = remaining_seconds_2 < total_seconds * threshold
assert should_renew_2 is False
# Scenario 3: Exactly 12 hours remaining (at threshold) -> should NOT renew
remaining_seconds_3 = 12 * 3600
should_renew_3 = remaining_seconds_3 < total_seconds * threshold
assert should_renew_3 is False
@pytest.mark.offline
def test_renewal_cache_cleanup():
"""Test that renewal cache can be cleared"""
# Use global cache
global _token_renewal_cache
# Clear first
_token_renewal_cache.clear()
# Add some entries
_token_renewal_cache["user1"] = time.time()
_token_renewal_cache["user2"] = time.time()
assert len(_token_renewal_cache) == 2
# Clear cache
_token_renewal_cache.clear()
assert len(_token_renewal_cache) == 0
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])