diff --git a/env.example b/env.example index 15bcdfd3..894e7e85 100644 --- a/env.example +++ b/env.example @@ -49,6 +49,23 @@ OLLAMA_EMULATING_MODEL_TAG=latest # GUEST_TOKEN_EXPIRE_HOURS=24 # JWT_ALGORITHM=HS256 +### Token Auto-Renewal Configuration (Sliding Window Expiration) +### Enable automatic token renewal to prevent active users from being logged out +### When enabled, tokens will be automatically renewed when remaining time < threshold +# TOKEN_AUTO_RENEW=true +### Token renewal threshold (0.0 - 1.0) +### Renew token when remaining time < (total time * threshold) +### Default: 0.5 (renew when 50% time remaining) +### Examples: +### 0.5 = renew when 24h token has 12h left +### 0.25 = renew when 24h token has 6h left +# TOKEN_RENEW_THRESHOLD=0.5 +### Note: Token renewal is automatically skipped for certain endpoints: +### - /health: Health check endpoint (no authentication required) +### - /documents/paginated: Frequently polled by client (5-30s interval) +### - /documents/pipeline_status: Very frequently polled by client (2s interval) +### - Rate limit: Minimum 60 seconds between renewals for same user + ### API-Key to access LightRAG Server API ### Use this key in HTTP requests with the 'X-API-Key' header ### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query diff --git a/lightrag/api/config.py b/lightrag/api/config.py index 4d8ab1e1..4781d1cc 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -403,10 +403,14 @@ def parse_args() -> argparse.Namespace: # For JWT Auth args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "") args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret") - args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int) - args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int) + args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, float) + args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, float) args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256") + # Token auto-renewal configuration (sliding window expiration) + args.token_auto_renew = get_env_value("TOKEN_AUTO_RENEW", True, bool) + args.token_renew_threshold = get_env_value("TOKEN_RENEW_THRESHOLD", 0.5, float) + # Rerank model configuration args.rerank_model = get_env_value("RERANK_MODEL", None) args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None) diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index f128c2a8..137a5335 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -451,6 +451,9 @@ def create_app(args): allow_credentials=True, allow_methods=["*"], allow_headers=["*"], + expose_headers=[ + "X-New-Token" + ], # Expose token renewal header for cross-origin requests ) # Create combined auth dependency for all endpoints diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index 47751772..17922b28 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -6,18 +6,39 @@ import os import argparse from typing import Optional, List, Tuple import sys +import time +import logging from ascii_colors import ASCIIColors from lightrag.api import __api_version__ as api_version from lightrag import __version__ as core_version from lightrag.constants import ( DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE, ) -from fastapi import HTTPException, Security, Request, status +from fastapi import HTTPException, Security, Request, Response, status from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from starlette.status import HTTP_403_FORBIDDEN from .auth import auth_handler from .config import ollama_server_infos, global_args, get_env_value +logger = logging.getLogger("lightrag") + +# ========== Token Renewal Rate Limiting ========== +# Cache to track last renewal time per user (username as key) +# Format: {username: last_renewal_timestamp} +_token_renewal_cache: dict[str, float] = {} +_RENEWAL_MIN_INTERVAL = 60 # Minimum 60 seconds between renewals for same user + +# ========== Token Renewal Path Exclusions ========== +# Paths that should NOT trigger token auto-renewal +# - /health: Health check endpoint, no login required +# - /documents/paginated: Client polls this frequently (5-30s), renewal not needed +# - /documents/pipeline_status: Client polls this very frequently (2s), renewal not needed +_TOKEN_RENEWAL_SKIP_PATHS = [ + "/health", + "/documents/paginated", + "/documents/pipeline_status", +] + def check_env_file(): """ @@ -87,6 +108,7 @@ def get_combined_auth_dependency(api_key: Optional[str] = None): async def combined_dependency( request: Request, + response: Response, # Added: needed to return new token via response header token: str = Security(oauth2_scheme), api_key_header_value: Optional[str] = None if api_key_header is None @@ -104,6 +126,80 @@ def get_combined_auth_dependency(api_key: Optional[str] = None): if token: try: token_info = auth_handler.validate_token(token) + + # ========== Token Auto-Renewal Logic ========== + from lightrag.api.config import global_args + from datetime import datetime + + if global_args.token_auto_renew: + # Check if current path should skip token renewal + skip_renewal = any( + path == skip_path or path.startswith(skip_path + "/") + for skip_path in _TOKEN_RENEWAL_SKIP_PATHS + ) + + if skip_renewal: + logger.debug(f"Token auto-renewal skipped for path: {path}") + else: + try: + expire_time = token_info.get("exp") + if expire_time: + # Calculate remaining time ratio + now = datetime.utcnow() + remaining_seconds = (expire_time - now).total_seconds() + + # Get original token expiration duration + role = token_info.get("role", "user") + total_hours = ( + auth_handler.guest_expire_hours + if role == "guest" + else auth_handler.expire_hours + ) + total_seconds = total_hours * 3600 + + # Issue new token if remaining time < threshold + if ( + remaining_seconds + < total_seconds * global_args.token_renew_threshold + ): + # ========== Rate Limiting Check ========== + username = token_info["username"] + current_time = time.time() + last_renewal = _token_renewal_cache.get(username, 0) + time_since_last_renewal = ( + current_time - last_renewal + ) + + # Only renew if enough time has passed since last renewal + if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL: + new_token = auth_handler.create_token( + username=username, + role=role, + metadata=token_info.get("metadata", {}), + ) + # Return new token via response header + response.headers["X-New-Token"] = new_token + + # Update renewal cache + _token_renewal_cache[username] = current_time + + # Optional: log renewal + logger.info( + f"Token auto-renewed for user {username} " + f"(role: {role}, remaining: {remaining_seconds:.0f}s)" + ) + else: + # Log skip due to rate limit + logger.debug( + f"Token renewal skipped for {username} " + f"(rate limit: last renewal {time_since_last_renewal:.0f}s ago)" + ) + # ========== End of Rate Limiting Check ========== + except Exception as e: + # Renewal failure should not affect normal request, just log + logger.warning(f"Token auto-renew failed: {e}") + # ========== End of Token Auto-Renewal Logic ========== + # Accept guest token if no auth is configured if not auth_configured and token_info.get("role") == "guest": return diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index 7cf1aec6..3cde0709 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -2,6 +2,7 @@ import axios, { AxiosError } from 'axios' import { backendBaseUrl, popularLabelsDefaultLimit, searchLabelsDefaultLimit } from '@/lib/constants' import { errorMessage } from '@/lib/utils' import { useSettingsStore } from '@/stores/settings' +import { useAuthStore } from '@/stores/state' import { navigationService } from '@/services/navigation' // Types @@ -285,8 +286,62 @@ const axiosInstance = axios.create({ } }) +// ========== Token Management ========== +// Prevent multiple requests from triggering token refresh simultaneously +let isRefreshingGuestToken = false; +let refreshTokenPromise: Promise | null = null; + +// Silent refresh for guest token +const silentRefreshGuestToken = async (): Promise => { + // If already refreshing, return the same Promise + if (isRefreshingGuestToken && refreshTokenPromise) { + return refreshTokenPromise; + } + + isRefreshingGuestToken = true; + refreshTokenPromise = (async () => { + try { + // Call /auth-status to get new guest token + const response = await axios.get('/auth-status', { + baseURL: backendBaseUrl, + // This request must skip the interceptor to avoid adding expired token + headers: { 'X-Skip-Interceptor': 'true' } + }); + + if (response.data.access_token && !response.data.auth_configured) { + const newToken = response.data.access_token; + // Update localStorage + localStorage.setItem('LIGHTRAG-API-TOKEN', newToken); + // Update auth state + useAuthStore.getState().login( + newToken, + true, + response.data.core_version, + response.data.api_version, + response.data.webui_title || null, + response.data.webui_description || null + ); + return newToken; + } else { + throw new Error('Failed to get guest token'); + } + } finally { + isRefreshingGuestToken = false; + refreshTokenPromise = null; + } + })(); + + return refreshTokenPromise; +}; + // Interceptor: add api key and check authentication axiosInstance.interceptors.request.use((config) => { + // Skip interceptor for token refresh requests + if (config.headers['X-Skip-Interceptor']) { + delete config.headers['X-Skip-Interceptor']; + return config; + } + const apiKey = useSettingsStore.getState().apiKey const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); @@ -300,20 +355,88 @@ axiosInstance.interceptors.request.use((config) => { return config }) -// Interceptor:hanle error +// Interceptor:handle token renewal and authentication errors axiosInstance.interceptors.response.use( - (response) => response, - (error: AxiosError) => { + (response) => { + // ========== Check for new token from backend ========== + const newToken = response.headers['x-new-token']; + if (newToken) { + localStorage.setItem('LIGHTRAG-API-TOKEN', newToken); + + // Optional: log in development mode + if (import.meta.env.DEV) { + console.log('[Auth] Token auto-renewed by backend'); + } + + // Update auth state with renewal tracking + try { + const payload = JSON.parse(atob(newToken.split('.')[1])); + const authStore = useAuthStore.getState(); + if (authStore.isAuthenticated) { + // Track token renewal time and expiration + const renewalTime = Date.now(); + const expiresAt = payload.exp ? payload.exp * 1000 : 0; + authStore.setTokenRenewal(renewalTime, expiresAt); + + // Update username (usually unchanged, but just in case) + const newUsername = payload.sub; + if (newUsername && newUsername !== authStore.username) { + // Need to add setUsername method or just update via login + // For now, we'll skip username update as it's rare + } + } + } catch (error) { + console.warn('[Auth] Failed to parse renewed token:', error); + } + } + // ========== End of token renewal check ========== + + return response; + }, + async (error: AxiosError) => { if (error.response) { if (error.response?.status === 401) { - // For login API, throw error directly - if (error.config?.url?.includes('/login')) { + const originalRequest = error.config; + + // 1. For login API, throw error directly + if (originalRequest?.url?.includes('/login')) { throw error; } - // For other APIs, navigate to login page - navigationService.navigateToLogin(); - // return a reject Promise + // 2. Prevent infinite retry + if (originalRequest && (originalRequest as any)._retry) { + navigationService.navigateToLogin(); + return Promise.reject(new Error('Authentication required')); + } + + // 3. Check if in guest mode + const authStore = useAuthStore.getState(); + const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN'); + const isGuest = currentToken && authStore.isGuestMode; + + // 4. Guest mode: silent refresh and retry + if (isGuest && originalRequest) { + try { + const newToken = await silentRefreshGuestToken(); + + // Mark as retried to prevent infinite loop + (originalRequest as any)._retry = true; + + // Update token in request headers + originalRequest.headers['Authorization'] = `Bearer ${newToken}`; + + // Retry original request + return axiosInstance(originalRequest); + } catch (refreshError) { + console.error('Failed to refresh guest token:', refreshError); + // Refresh failed, navigate to login + navigationService.navigateToLogin(); + return Promise.reject(new Error('Failed to refresh authentication')); + } + } + + // 5. Non-guest mode: navigate to login page + navigationService.navigateToLogin(); return Promise.reject(new Error('Authentication required')); } throw new Error( @@ -418,7 +541,88 @@ export const queryTextStream = async ( if (!response.ok) { // Handle 401 Unauthorized error specifically if (response.status === 401) { - // For consistency with axios interceptor, navigate to login page + // Check if in guest mode + const authStore = useAuthStore.getState(); + const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN'); + const isGuest = currentToken && authStore.isGuestMode; + + if (isGuest) { + try { + // Silent refresh token for guest mode + const newToken = await silentRefreshGuestToken(); + + // Retry stream request with new token + const retryHeaders = { ...headers }; + retryHeaders['Authorization'] = `Bearer ${newToken}`; + + const retryResponse = await fetch(`${backendBaseUrl}/query/stream`, { + method: 'POST', + headers: retryHeaders, + body: JSON.stringify(request), + }); + + if (!retryResponse.ok) { + throw new Error(`HTTP error! status: ${retryResponse.status}`); + } + + // Retry successful, process stream response + // Re-execute the stream processing logic with retryResponse + if (!retryResponse.body) { + throw new Error('Response body is null'); + } + + const reader = retryResponse.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split('\n'); + buffer = lines.pop() || ''; + + for (const line of lines) { + if (line.trim()) { + try { + const parsed = JSON.parse(line); + if (parsed.response) { + onChunk(parsed.response); + } else if (parsed.error) { + onError?.(parsed.error); + } + } catch (parseError) { + console.error('Failed to parse JSON:', parseError, 'Line:', line); + onError?.(`JSON parse error: ${parseError}`); + } + } + } + } + + // Process any remaining data in buffer + if (buffer.trim()) { + try { + const parsed = JSON.parse(buffer); + if (parsed.response) { + onChunk(parsed.response); + } else if (parsed.error) { + onError?.(parsed.error); + } + } catch (parseError) { + console.error('Failed to parse final buffer:', parseError); + } + } + + return; // Successfully completed retry + } catch (refreshError) { + console.error('Failed to refresh guest token for streaming:', refreshError); + navigationService.navigateToLogin(); + throw new Error('Failed to refresh authentication'); + } + } + + // Non-guest mode: navigate to login page navigationService.navigateToLogin(); // Create a specific authentication error diff --git a/lightrag_webui/src/stores/state.ts b/lightrag_webui/src/stores/state.ts index bbf155b1..3a076ad3 100644 --- a/lightrag_webui/src/stores/state.ts +++ b/lightrag_webui/src/stores/state.ts @@ -33,11 +33,14 @@ interface AuthState { username: string | null; // login username webuiTitle: string | null; // Custom title webuiDescription: string | null; // Title description + lastTokenRenewal: string | null; // Human-readable local time of last token renewal (for debugging and monitoring) + tokenExpiresAt: number | null; // Token expiration timestamp (extracted from JWT) login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void; logout: () => void; setVersion: (coreVersion: string | null, apiVersion: string | null) => void; setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void; + setTokenRenewal: (renewalTime: number, expiresAt: number) => void; // Track token renewal } const useBackendStateStoreBase = create()((set, get) => ({ @@ -156,7 +159,19 @@ const useBackendState = createSelectors(useBackendStateStoreBase) export { useBackendState } -const parseTokenPayload = (token: string): { sub?: string; role?: string } => { +// Format timestamp to human-readable local time with timezone +const formatTimestampToLocalString = (timestamp: number): string => { + const date = new Date(timestamp); + // Use Swedish locale 'sv-SE' to get YYYY-MM-DD HH:mm:ss format + const localTime = date.toLocaleString('sv-SE', { hour12: false }); + // Get timezone offset + const offsetMinutes = -date.getTimezoneOffset(); + const offsetHours = Math.floor(Math.abs(offsetMinutes) / 60); + const offsetSign = offsetMinutes >= 0 ? '+' : '-'; + return `${localTime} (UTC${offsetSign}${offsetHours})`; +}; + +const parseTokenPayload = (token: string): { sub?: string; role?: string; exp?: number } => { try { // JWT tokens are in the format: header.payload.signature const parts = token.split('.'); @@ -179,13 +194,20 @@ const isGuestToken = (token: string): boolean => { return payload.role === 'guest'; }; -const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => { +const getTokenExpiresAt = (token: string): number | null => { + const payload = parseTokenPayload(token); + return payload.exp ? payload.exp * 1000 : null; // Convert to milliseconds +}; + +const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null; lastTokenRenewal: string | null; tokenExpiresAt: number | null } => { const token = localStorage.getItem('LIGHTRAG-API-TOKEN'); const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION'); const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION'); const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE'); const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION'); + const lastTokenRenewal = localStorage.getItem('LIGHTRAG-LAST-TOKEN-RENEWAL'); const username = token ? getUsernameFromToken(token) : null; + const tokenExpiresAt = token ? getTokenExpiresAt(token) : null; if (!token) { return { @@ -196,6 +218,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core username: null, webuiTitle: webuiTitle, webuiDescription: webuiDescription, + lastTokenRenewal: null, + tokenExpiresAt: null, }; } @@ -207,6 +231,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core username: username, webuiTitle: webuiTitle, webuiDescription: webuiDescription, + lastTokenRenewal: lastTokenRenewal, + tokenExpiresAt: tokenExpiresAt, }; }; @@ -222,6 +248,8 @@ export const useAuthStore = create(set => { username: initialState.username, webuiTitle: initialState.webuiTitle, webuiDescription: initialState.webuiDescription, + lastTokenRenewal: initialState.lastTokenRenewal, + tokenExpiresAt: initialState.tokenExpiresAt, login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => { localStorage.setItem('LIGHTRAG-API-TOKEN', token); @@ -246,6 +274,13 @@ export const useAuthStore = create(set => { } const username = getUsernameFromToken(token); + const tokenExpiresAt = getTokenExpiresAt(token); + const now = Date.now(); + const formattedTime = formatTimestampToLocalString(now); + + // Initialize token issuance time with human-readable format + localStorage.setItem('LIGHTRAG-LAST-TOKEN-RENEWAL', formattedTime); + set({ isAuthenticated: true, isGuestMode: isGuest, @@ -254,11 +289,14 @@ export const useAuthStore = create(set => { apiVersion: apiVersion, webuiTitle: webuiTitle, webuiDescription: webuiDescription, + tokenExpiresAt: tokenExpiresAt, + lastTokenRenewal: formattedTime, }); }, logout: () => { localStorage.removeItem('LIGHTRAG-API-TOKEN'); + localStorage.removeItem('LIGHTRAG-LAST-TOKEN-RENEWAL'); const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION'); const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION'); @@ -273,6 +311,8 @@ export const useAuthStore = create(set => { apiVersion: apiVersion, webuiTitle: webuiTitle, webuiDescription: webuiDescription, + lastTokenRenewal: null, + tokenExpiresAt: null, }); }, @@ -311,6 +351,19 @@ export const useAuthStore = create(set => { webuiTitle: webuiTitle, webuiDescription: webuiDescription }); + }, + + setTokenRenewal: (renewalTime, expiresAt) => { + const formattedTime = formatTimestampToLocalString(renewalTime); + + // Update localStorage with human-readable format + localStorage.setItem('LIGHTRAG-LAST-TOKEN-RENEWAL', formattedTime); + + // Update state + set({ + lastTokenRenewal: formattedTime, + tokenExpiresAt: expiresAt + }); } }; }); diff --git a/tests/test_token_auto_renewal.py b/tests/test_token_auto_renewal.py new file mode 100644 index 00000000..2cfcbd67 --- /dev/null +++ b/tests/test_token_auto_renewal.py @@ -0,0 +1,408 @@ +""" +Pytest unit tests for token auto-renewal functionality + +Tests: +1. Backend token renewal logic +2. Rate limiting for token renewals +3. Token renewal state tracking +""" + +import pytest +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock +from fastapi import Response +import time +import sys + +# Mock the config before importing utils_api +sys.modules["lightrag.api.config"] = Mock() +sys.modules["lightrag.api.auth"] = Mock() + +# Create a simple token renewal cache for testing +_token_renewal_cache = {} +_RENEWAL_MIN_INTERVAL = 60 + + +@pytest.mark.offline +class TestTokenRenewal: + """Tests for token auto-renewal logic""" + + @pytest.fixture + def mock_auth_handler(self): + """Mock authentication handler""" + handler = Mock() + handler.guest_expire_hours = 24 + handler.expire_hours = 24 + handler.create_token = Mock(return_value="new-token-12345") + return handler + + @pytest.fixture + def mock_global_args(self): + """Mock global configuration""" + args = Mock() + args.token_auto_renew = True + args.token_renew_threshold = 0.5 + return args + + @pytest.fixture + def mock_token_info_guest(self): + """Mock token info for guest user""" + # Token with 10 hours remaining (below 50% of 24 hours) + exp_time = datetime.now(timezone.utc) + timedelta(hours=10) + return { + "username": "guest", + "role": "guest", + "exp": exp_time, + "metadata": {"auth_mode": "disabled"}, + } + + @pytest.fixture + def mock_token_info_user(self): + """Mock token info for regular user""" + # Token with 10 hours remaining (below 50% of 24 hours) + exp_time = datetime.now(timezone.utc) + timedelta(hours=10) + return { + "username": "testuser", + "role": "user", + "exp": exp_time, + "metadata": {"auth_mode": "enabled"}, + } + + @pytest.fixture + def mock_token_info_above_threshold(self): + """Mock token info with time above renewal threshold""" + # Token with 20 hours remaining (above 50% of 24 hours) + exp_time = datetime.now(timezone.utc) + timedelta(hours=20) + return { + "username": "testuser", + "role": "user", + "exp": exp_time, + "metadata": {"auth_mode": "enabled"}, + } + + def test_token_renewal_when_below_threshold( + self, mock_auth_handler, mock_global_args, mock_token_info_user + ): + """Test that token is renewed when remaining time < threshold""" + # Use global cache + global _token_renewal_cache + + # Clear cache + _token_renewal_cache.clear() + + response = Mock(spec=Response) + response.headers = {} + + # Simulate the renewal logic + expire_time = mock_token_info_user["exp"] + now = datetime.now(timezone.utc) + remaining_seconds = (expire_time - now).total_seconds() + + role = mock_token_info_user["role"] + total_hours = ( + mock_auth_handler.expire_hours + if role == "user" + else mock_auth_handler.guest_expire_hours + ) + total_seconds = total_hours * 3600 + + # Should renew because remaining_seconds < total_seconds * 0.5 + should_renew = ( + remaining_seconds < total_seconds * mock_global_args.token_renew_threshold + ) + assert should_renew is True + + # Simulate renewal + username = mock_token_info_user["username"] + current_time = time.time() + last_renewal = _token_renewal_cache.get(username, 0) + time_since_last_renewal = current_time - last_renewal + + # Should pass rate limit (first renewal) + assert time_since_last_renewal >= 60 or last_renewal == 0 + + # Perform renewal + new_token = mock_auth_handler.create_token( + username=username, role=role, metadata=mock_token_info_user["metadata"] + ) + response.headers["X-New-Token"] = new_token + _token_renewal_cache[username] = current_time + + # Verify + assert "X-New-Token" in response.headers + assert response.headers["X-New-Token"] == "new-token-12345" + assert username in _token_renewal_cache + + def test_token_no_renewal_when_above_threshold( + self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold + ): + """Test that token is NOT renewed when remaining time > threshold""" + response = Mock(spec=Response) + response.headers = {} + + expire_time = mock_token_info_above_threshold["exp"] + now = datetime.now(timezone.utc) + remaining_seconds = (expire_time - now).total_seconds() + + mock_token_info_above_threshold["role"] + total_hours = mock_auth_handler.expire_hours + total_seconds = total_hours * 3600 + + # Should NOT renew because remaining_seconds > total_seconds * 0.5 + should_renew = ( + remaining_seconds < total_seconds * mock_global_args.token_renew_threshold + ) + assert should_renew is False + + # No renewal should happen + assert "X-New-Token" not in response.headers + + def test_token_renewal_disabled( + self, mock_auth_handler, mock_global_args, mock_token_info_user + ): + """Test that no renewal happens when TOKEN_AUTO_RENEW=false""" + mock_global_args.token_auto_renew = False + response = Mock(spec=Response) + response.headers = {} + + # Auto-renewal is disabled, so even if below threshold, no renewal + if not mock_global_args.token_auto_renew: + # Skip renewal logic + pass + + assert "X-New-Token" not in response.headers + + def test_token_renewal_for_guest_mode( + self, mock_auth_handler, mock_global_args, mock_token_info_guest + ): + """Test that guest tokens are renewed correctly""" + # Use global cache + global _token_renewal_cache + + _token_renewal_cache.clear() + + response = Mock(spec=Response) + response.headers = {} + + expire_time = mock_token_info_guest["exp"] + now = datetime.now(timezone.utc) + remaining_seconds = (expire_time - now).total_seconds() + + role = mock_token_info_guest["role"] + total_hours = mock_auth_handler.guest_expire_hours + total_seconds = total_hours * 3600 + + should_renew = ( + remaining_seconds < total_seconds * mock_global_args.token_renew_threshold + ) + assert should_renew is True + + # Renewal for guest + username = mock_token_info_guest["username"] + new_token = mock_auth_handler.create_token( + username=username, role=role, metadata=mock_token_info_guest["metadata"] + ) + response.headers["X-New-Token"] = new_token + _token_renewal_cache[username] = time.time() + + assert "X-New-Token" in response.headers + assert username in _token_renewal_cache + + +@pytest.mark.offline +class TestRateLimiting: + """Tests for token renewal rate limiting""" + + @pytest.fixture + def mock_auth_handler(self): + """Mock authentication handler""" + handler = Mock() + handler.expire_hours = 24 + handler.create_token = Mock(return_value="new-token-12345") + return handler + + def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler): + """Test that second renewal within 60s is blocked""" + # Use global cache and constant + global _token_renewal_cache, _RENEWAL_MIN_INTERVAL + + username = "testuser" + _token_renewal_cache.clear() + + # First renewal + current_time_1 = time.time() + _token_renewal_cache[username] = current_time_1 + + response_1 = Mock(spec=Response) + response_1.headers = {} + response_1.headers["X-New-Token"] = "new-token-12345" + + # Immediate second renewal attempt (within 60s) + current_time_2 = time.time() # Almost same time + last_renewal = _token_renewal_cache.get(username, 0) + time_since_last_renewal = current_time_2 - last_renewal + + # Should be blocked by rate limit + assert time_since_last_renewal < _RENEWAL_MIN_INTERVAL + + response_2 = Mock(spec=Response) + response_2.headers = {} + + # No new token should be issued + if time_since_last_renewal < _RENEWAL_MIN_INTERVAL: + # Rate limited, skip renewal + pass + + assert "X-New-Token" not in response_2.headers + + def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler): + """Test that renewal succeeds after 60s interval""" + # Use global cache and constant + global _token_renewal_cache, _RENEWAL_MIN_INTERVAL + + username = "testuser" + _token_renewal_cache.clear() + + # First renewal at time T + first_renewal_time = time.time() - 61 # 61 seconds ago + _token_renewal_cache[username] = first_renewal_time + + # Second renewal attempt now + current_time = time.time() + last_renewal = _token_renewal_cache.get(username, 0) + time_since_last_renewal = current_time - last_renewal + + # Should pass rate limit (>60s elapsed) + assert time_since_last_renewal >= _RENEWAL_MIN_INTERVAL + + response = Mock(spec=Response) + response.headers = {} + + if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL: + new_token = mock_auth_handler.create_token( + username=username, role="user", metadata={} + ) + response.headers["X-New-Token"] = new_token + _token_renewal_cache[username] = current_time + + assert "X-New-Token" in response.headers + assert response.headers["X-New-Token"] == "new-token-12345" + + def test_rate_limit_per_user(self, mock_auth_handler): + """Test that different users have independent rate limits""" + # Use global cache + global _token_renewal_cache + + _token_renewal_cache.clear() + + user1 = "user1" + user2 = "user2" + + current_time = time.time() + + # User1 gets renewal + _token_renewal_cache[user1] = current_time + + # User2 should still be able to get renewal (independent cache) + last_renewal_user2 = _token_renewal_cache.get(user2, 0) + assert last_renewal_user2 == 0 # No previous renewal + + # User2 can renew + _token_renewal_cache[user2] = current_time + + # Both users should have entries + assert user1 in _token_renewal_cache + assert user2 in _token_renewal_cache + assert _token_renewal_cache[user1] == _token_renewal_cache[user2] + + +@pytest.mark.offline +class TestTokenExpirationCalculation: + """Tests for token expiration time calculation""" + + def test_expiration_extraction_from_jwt(self): + """Test extracting expiration time from JWT token""" + import base64 + import json + + # Create a mock JWT payload + exp_timestamp = int( + (datetime.now(timezone.utc) + timedelta(hours=24)).timestamp() + ) + payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp} + + # Encode as base64 (simulating JWT structure: header.payload.signature) + payload_b64 = base64.b64encode(json.dumps(payload).encode()).decode() + mock_token = f"header.{payload_b64}.signature" + + # Simulate extraction + parts = mock_token.split(".") + assert len(parts) == 3 + + decoded_payload = json.loads(base64.b64decode(parts[1])) + assert decoded_payload["exp"] == exp_timestamp + assert decoded_payload["sub"] == "testuser" + + def test_remaining_time_calculation(self): + """Test calculation of remaining token time""" + # Token expires in 10 hours + exp_time = datetime.now(timezone.utc) + timedelta(hours=10) + now = datetime.now(timezone.utc) + + remaining_seconds = (exp_time - now).total_seconds() + + # Should be approximately 10 hours (36000 seconds) + assert 35990 < remaining_seconds < 36010 + + # Calculate percentage remaining (for 24-hour token) + total_seconds = 24 * 3600 + percentage_remaining = remaining_seconds / total_seconds + + # Should be approximately 41.67% remaining + assert 0.41 < percentage_remaining < 0.42 + + def test_threshold_comparison(self): + """Test threshold-based renewal decision""" + threshold = 0.5 + total_hours = 24 + total_seconds = total_hours * 3600 + + # Scenario 1: 10 hours remaining -> should renew + remaining_seconds_1 = 10 * 3600 + should_renew_1 = remaining_seconds_1 < total_seconds * threshold + assert should_renew_1 is True + + # Scenario 2: 20 hours remaining -> should NOT renew + remaining_seconds_2 = 20 * 3600 + should_renew_2 = remaining_seconds_2 < total_seconds * threshold + assert should_renew_2 is False + + # Scenario 3: Exactly 12 hours remaining (at threshold) -> should NOT renew + remaining_seconds_3 = 12 * 3600 + should_renew_3 = remaining_seconds_3 < total_seconds * threshold + assert should_renew_3 is False + + +@pytest.mark.offline +def test_renewal_cache_cleanup(): + """Test that renewal cache can be cleared""" + # Use global cache + global _token_renewal_cache + + # Clear first + _token_renewal_cache.clear() + + # Add some entries + _token_renewal_cache["user1"] = time.time() + _token_renewal_cache["user2"] = time.time() + + assert len(_token_renewal_cache) == 2 + + # Clear cache + _token_renewal_cache.clear() + + assert len(_token_renewal_cache) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"])