Implement token auto-renewal and sliding window expiration mechanism
* Add backend token renewal logic * Handle X-New-Token in frontend * Add rate limiting and config options * Implement silent refresh for guests * Add unit tests for renewal logic
This commit is contained in:
16
env.example
16
env.example
@@ -49,6 +49,22 @@ OLLAMA_EMULATING_MODEL_TAG=latest
|
||||
# GUEST_TOKEN_EXPIRE_HOURS=24
|
||||
# JWT_ALGORITHM=HS256
|
||||
|
||||
### Token Auto-Renewal Configuration (Sliding Window Expiration)
|
||||
### Enable automatic token renewal to prevent active users from being logged out
|
||||
### When enabled, tokens will be automatically renewed when remaining time < threshold
|
||||
# TOKEN_AUTO_RENEW=true
|
||||
### Token renewal threshold (0.0 - 1.0)
|
||||
### Renew token when remaining time < (total time * threshold)
|
||||
### Default: 0.5 (renew when 50% time remaining)
|
||||
### Examples:
|
||||
### 0.5 = renew when 24h token has 12h left
|
||||
### 0.25 = renew when 24h token has 6h left
|
||||
# TOKEN_RENEW_THRESHOLD=0.5
|
||||
### Note: Token renewal is automatically skipped for certain endpoints:
|
||||
### - /health: Health check endpoint (no authentication required)
|
||||
### - /documents/paginated: Frequently polled by client (renewal not needed)
|
||||
### - Rate limit: Minimum 60 seconds between renewals for same user
|
||||
|
||||
### API-Key to access LightRAG Server API
|
||||
### Use this key in HTTP requests with the 'X-API-Key' header
|
||||
### Example: curl -H "X-API-Key: your-secure-api-key-here" http://localhost:9621/query
|
||||
|
||||
@@ -403,10 +403,14 @@ def parse_args() -> argparse.Namespace:
|
||||
# For JWT Auth
|
||||
args.auth_accounts = get_env_value("AUTH_ACCOUNTS", "")
|
||||
args.token_secret = get_env_value("TOKEN_SECRET", "lightrag-jwt-default-secret")
|
||||
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, int)
|
||||
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
|
||||
args.token_expire_hours = get_env_value("TOKEN_EXPIRE_HOURS", 48, float)
|
||||
args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, float)
|
||||
args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
|
||||
|
||||
# Token auto-renewal configuration (sliding window expiration)
|
||||
args.token_auto_renew = get_env_value("TOKEN_AUTO_RENEW", True, bool)
|
||||
args.token_renew_threshold = get_env_value("TOKEN_RENEW_THRESHOLD", 0.5, float)
|
||||
|
||||
# Rerank model configuration
|
||||
args.rerank_model = get_env_value("RERANK_MODEL", None)
|
||||
args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
|
||||
|
||||
@@ -451,6 +451,9 @@ def create_app(args):
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=[
|
||||
"X-New-Token"
|
||||
], # Expose token renewal header for cross-origin requests
|
||||
)
|
||||
|
||||
# Create combined auth dependency for all endpoints
|
||||
|
||||
@@ -6,18 +6,37 @@ import os
|
||||
import argparse
|
||||
from typing import Optional, List, Tuple
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from ascii_colors import ASCIIColors
|
||||
from lightrag.api import __api_version__ as api_version
|
||||
from lightrag import __version__ as core_version
|
||||
from lightrag.constants import (
|
||||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
|
||||
)
|
||||
from fastapi import HTTPException, Security, Request, status
|
||||
from fastapi import HTTPException, Security, Request, Response, status
|
||||
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from .auth import auth_handler
|
||||
from .config import ollama_server_infos, global_args, get_env_value
|
||||
|
||||
logger = logging.getLogger("lightrag")
|
||||
|
||||
# ========== Token Renewal Rate Limiting ==========
|
||||
# Cache to track last renewal time per user (username as key)
|
||||
# Format: {username: last_renewal_timestamp}
|
||||
_token_renewal_cache: dict[str, float] = {}
|
||||
_RENEWAL_MIN_INTERVAL = 60 # Minimum 60 seconds between renewals for same user
|
||||
|
||||
# ========== Token Renewal Path Exclusions ==========
|
||||
# Paths that should NOT trigger token auto-renewal
|
||||
# - /health: Health check endpoint, no login required
|
||||
# - /documents/paginated: Client polls this frequently, renewal not needed
|
||||
_TOKEN_RENEWAL_SKIP_PATHS = [
|
||||
"/health",
|
||||
"/documents/paginated",
|
||||
]
|
||||
|
||||
|
||||
def check_env_file():
|
||||
"""
|
||||
@@ -87,6 +106,7 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
|
||||
async def combined_dependency(
|
||||
request: Request,
|
||||
response: Response, # Added: needed to return new token via response header
|
||||
token: str = Security(oauth2_scheme),
|
||||
api_key_header_value: Optional[str] = None
|
||||
if api_key_header is None
|
||||
@@ -104,6 +124,80 @@ def get_combined_auth_dependency(api_key: Optional[str] = None):
|
||||
if token:
|
||||
try:
|
||||
token_info = auth_handler.validate_token(token)
|
||||
|
||||
# ========== Token Auto-Renewal Logic ==========
|
||||
from lightrag.api.config import global_args
|
||||
from datetime import datetime
|
||||
|
||||
if global_args.token_auto_renew:
|
||||
# Check if current path should skip token renewal
|
||||
skip_renewal = any(
|
||||
path == skip_path or path.startswith(skip_path + "/")
|
||||
for skip_path in _TOKEN_RENEWAL_SKIP_PATHS
|
||||
)
|
||||
|
||||
if skip_renewal:
|
||||
logger.debug(f"Token auto-renewal skipped for path: {path}")
|
||||
else:
|
||||
try:
|
||||
expire_time = token_info.get("exp")
|
||||
if expire_time:
|
||||
# Calculate remaining time ratio
|
||||
now = datetime.utcnow()
|
||||
remaining_seconds = (expire_time - now).total_seconds()
|
||||
|
||||
# Get original token expiration duration
|
||||
role = token_info.get("role", "user")
|
||||
total_hours = (
|
||||
auth_handler.guest_expire_hours
|
||||
if role == "guest"
|
||||
else auth_handler.expire_hours
|
||||
)
|
||||
total_seconds = total_hours * 3600
|
||||
|
||||
# Issue new token if remaining time < threshold
|
||||
if (
|
||||
remaining_seconds
|
||||
< total_seconds * global_args.token_renew_threshold
|
||||
):
|
||||
# ========== Rate Limiting Check ==========
|
||||
username = token_info["username"]
|
||||
current_time = time.time()
|
||||
last_renewal = _token_renewal_cache.get(username, 0)
|
||||
time_since_last_renewal = (
|
||||
current_time - last_renewal
|
||||
)
|
||||
|
||||
# Only renew if enough time has passed since last renewal
|
||||
if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
|
||||
new_token = auth_handler.create_token(
|
||||
username=username,
|
||||
role=role,
|
||||
metadata=token_info.get("metadata", {}),
|
||||
)
|
||||
# Return new token via response header
|
||||
response.headers["X-New-Token"] = new_token
|
||||
|
||||
# Update renewal cache
|
||||
_token_renewal_cache[username] = current_time
|
||||
|
||||
# Optional: log renewal
|
||||
logger.info(
|
||||
f"Token auto-renewed for user {username} "
|
||||
f"(role: {role}, remaining: {remaining_seconds:.0f}s)"
|
||||
)
|
||||
else:
|
||||
# Log skip due to rate limit
|
||||
logger.debug(
|
||||
f"Token renewal skipped for {username} "
|
||||
f"(rate limit: last renewal {time_since_last_renewal:.0f}s ago)"
|
||||
)
|
||||
# ========== End of Rate Limiting Check ==========
|
||||
except Exception as e:
|
||||
# Renewal failure should not affect normal request, just log
|
||||
logger.warning(f"Token auto-renew failed: {e}")
|
||||
# ========== End of Token Auto-Renewal Logic ==========
|
||||
|
||||
# Accept guest token if no auth is configured
|
||||
if not auth_configured and token_info.get("role") == "guest":
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ import axios, { AxiosError } from 'axios'
|
||||
import { backendBaseUrl, popularLabelsDefaultLimit, searchLabelsDefaultLimit } from '@/lib/constants'
|
||||
import { errorMessage } from '@/lib/utils'
|
||||
import { useSettingsStore } from '@/stores/settings'
|
||||
import { useAuthStore } from '@/stores/state'
|
||||
import { navigationService } from '@/services/navigation'
|
||||
|
||||
// Types
|
||||
@@ -285,8 +286,55 @@ const axiosInstance = axios.create({
|
||||
}
|
||||
})
|
||||
|
||||
// ========== Token Management ==========
|
||||
// Prevent multiple requests from triggering token refresh simultaneously
|
||||
let isRefreshingGuestToken = false;
|
||||
let refreshTokenPromise: Promise<string> | null = null;
|
||||
|
||||
// Silent refresh for guest token
|
||||
const silentRefreshGuestToken = async (): Promise<string> => {
|
||||
// If already refreshing, return the same Promise
|
||||
if (isRefreshingGuestToken && refreshTokenPromise) {
|
||||
return refreshTokenPromise;
|
||||
}
|
||||
|
||||
isRefreshingGuestToken = true;
|
||||
refreshTokenPromise = (async () => {
|
||||
try {
|
||||
// Call /auth-status to get new guest token
|
||||
const response = await axios.get('/auth-status', {
|
||||
baseURL: backendBaseUrl,
|
||||
// This request must skip the interceptor to avoid adding expired token
|
||||
headers: { 'X-Skip-Interceptor': 'true' }
|
||||
});
|
||||
|
||||
if (response.data.access_token && !response.data.auth_configured) {
|
||||
const newToken = response.data.access_token;
|
||||
// Update localStorage
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', newToken);
|
||||
// Update auth state
|
||||
useAuthStore.getState().login(newToken, true, response.data);
|
||||
return newToken;
|
||||
} else {
|
||||
throw new Error('Failed to get guest token');
|
||||
}
|
||||
} finally {
|
||||
isRefreshingGuestToken = false;
|
||||
refreshTokenPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
return refreshTokenPromise;
|
||||
};
|
||||
|
||||
// Interceptor: add api key and check authentication
|
||||
axiosInstance.interceptors.request.use((config) => {
|
||||
// Skip interceptor for token refresh requests
|
||||
if (config.headers['X-Skip-Interceptor']) {
|
||||
delete config.headers['X-Skip-Interceptor'];
|
||||
return config;
|
||||
}
|
||||
|
||||
const apiKey = useSettingsStore.getState().apiKey
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
|
||||
@@ -300,20 +348,88 @@ axiosInstance.interceptors.request.use((config) => {
|
||||
return config
|
||||
})
|
||||
|
||||
// Interceptor:hanle error
|
||||
// Interceptor:handle token renewal and authentication errors
|
||||
axiosInstance.interceptors.response.use(
|
||||
(response) => response,
|
||||
(error: AxiosError) => {
|
||||
(response) => {
|
||||
// ========== Check for new token from backend ==========
|
||||
const newToken = response.headers['x-new-token'];
|
||||
if (newToken) {
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', newToken);
|
||||
|
||||
// Optional: log in development mode
|
||||
if (import.meta.env.DEV) {
|
||||
console.log('[Auth] Token auto-renewed by backend');
|
||||
}
|
||||
|
||||
// Update auth state with renewal tracking
|
||||
try {
|
||||
const payload = JSON.parse(atob(newToken.split('.')[1]));
|
||||
const authStore = useAuthStore.getState();
|
||||
if (authStore.isAuthenticated) {
|
||||
// Track token renewal time and expiration
|
||||
const renewalTime = Date.now();
|
||||
const expiresAt = payload.exp ? payload.exp * 1000 : 0;
|
||||
authStore.setTokenRenewal(renewalTime, expiresAt);
|
||||
|
||||
// Update username (usually unchanged, but just in case)
|
||||
const newUsername = payload.sub;
|
||||
if (newUsername && newUsername !== authStore.username) {
|
||||
// Need to add setUsername method or just update via login
|
||||
// For now, we'll skip username update as it's rare
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('[Auth] Failed to parse renewed token:', error);
|
||||
}
|
||||
}
|
||||
// ========== End of token renewal check ==========
|
||||
|
||||
return response;
|
||||
},
|
||||
async (error: AxiosError) => {
|
||||
if (error.response) {
|
||||
if (error.response?.status === 401) {
|
||||
// For login API, throw error directly
|
||||
if (error.config?.url?.includes('/login')) {
|
||||
const originalRequest = error.config;
|
||||
|
||||
// 1. For login API, throw error directly
|
||||
if (originalRequest?.url?.includes('/login')) {
|
||||
throw error;
|
||||
}
|
||||
// For other APIs, navigate to login page
|
||||
navigationService.navigateToLogin();
|
||||
|
||||
// return a reject Promise
|
||||
// 2. Prevent infinite retry
|
||||
if (originalRequest && (originalRequest as any)._retry) {
|
||||
navigationService.navigateToLogin();
|
||||
return Promise.reject(new Error('Authentication required'));
|
||||
}
|
||||
|
||||
// 3. Check if in guest mode
|
||||
const authStore = useAuthStore.getState();
|
||||
const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
const isGuest = currentToken && authStore.isGuestMode;
|
||||
|
||||
// 4. Guest mode: silent refresh and retry
|
||||
if (isGuest && originalRequest) {
|
||||
try {
|
||||
const newToken = await silentRefreshGuestToken();
|
||||
|
||||
// Mark as retried to prevent infinite loop
|
||||
(originalRequest as any)._retry = true;
|
||||
|
||||
// Update token in request headers
|
||||
originalRequest.headers['Authorization'] = `Bearer ${newToken}`;
|
||||
|
||||
// Retry original request
|
||||
return axiosInstance(originalRequest);
|
||||
} catch (refreshError) {
|
||||
console.error('Failed to refresh guest token:', refreshError);
|
||||
// Refresh failed, navigate to login
|
||||
navigationService.navigateToLogin();
|
||||
return Promise.reject(new Error('Failed to refresh authentication'));
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Non-guest mode: navigate to login page
|
||||
navigationService.navigateToLogin();
|
||||
return Promise.reject(new Error('Authentication required'));
|
||||
}
|
||||
throw new Error(
|
||||
@@ -418,7 +534,88 @@ export const queryTextStream = async (
|
||||
if (!response.ok) {
|
||||
// Handle 401 Unauthorized error specifically
|
||||
if (response.status === 401) {
|
||||
// For consistency with axios interceptor, navigate to login page
|
||||
// Check if in guest mode
|
||||
const authStore = useAuthStore.getState();
|
||||
const currentToken = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
const isGuest = currentToken && authStore.isGuestMode;
|
||||
|
||||
if (isGuest) {
|
||||
try {
|
||||
// Silent refresh token for guest mode
|
||||
const newToken = await silentRefreshGuestToken();
|
||||
|
||||
// Retry stream request with new token
|
||||
const retryHeaders = { ...headers };
|
||||
retryHeaders['Authorization'] = `Bearer ${newToken}`;
|
||||
|
||||
const retryResponse = await fetch(`${backendBaseUrl}/query/stream`, {
|
||||
method: 'POST',
|
||||
headers: retryHeaders,
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!retryResponse.ok) {
|
||||
throw new Error(`HTTP error! status: ${retryResponse.status}`);
|
||||
}
|
||||
|
||||
// Retry successful, process stream response
|
||||
// Re-execute the stream processing logic with retryResponse
|
||||
if (!retryResponse.body) {
|
||||
throw new Error('Response body is null');
|
||||
}
|
||||
|
||||
const reader = retryResponse.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split('\n');
|
||||
buffer = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(line);
|
||||
if (parsed.response) {
|
||||
onChunk(parsed.response);
|
||||
} else if (parsed.error) {
|
||||
onError?.(parsed.error);
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse JSON:', parseError, 'Line:', line);
|
||||
onError?.(`JSON parse error: ${parseError}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process any remaining data in buffer
|
||||
if (buffer.trim()) {
|
||||
try {
|
||||
const parsed = JSON.parse(buffer);
|
||||
if (parsed.response) {
|
||||
onChunk(parsed.response);
|
||||
} else if (parsed.error) {
|
||||
onError?.(parsed.error);
|
||||
}
|
||||
} catch (parseError) {
|
||||
console.error('Failed to parse final buffer:', parseError);
|
||||
}
|
||||
}
|
||||
|
||||
return; // Successfully completed retry
|
||||
} catch (refreshError) {
|
||||
console.error('Failed to refresh guest token for streaming:', refreshError);
|
||||
navigationService.navigateToLogin();
|
||||
throw new Error('Failed to refresh authentication');
|
||||
}
|
||||
}
|
||||
|
||||
// Non-guest mode: navigate to login page
|
||||
navigationService.navigateToLogin();
|
||||
|
||||
// Create a specific authentication error
|
||||
|
||||
@@ -33,11 +33,14 @@ interface AuthState {
|
||||
username: string | null; // login username
|
||||
webuiTitle: string | null; // Custom title
|
||||
webuiDescription: string | null; // Title description
|
||||
lastTokenRenewal: number | null; // Timestamp of last token renewal (for debugging and monitoring)
|
||||
tokenExpiresAt: number | null; // Token expiration timestamp (extracted from JWT)
|
||||
|
||||
login: (token: string, isGuest?: boolean, coreVersion?: string | null, apiVersion?: string | null, webuiTitle?: string | null, webuiDescription?: string | null) => void;
|
||||
logout: () => void;
|
||||
setVersion: (coreVersion: string | null, apiVersion: string | null) => void;
|
||||
setCustomTitle: (webuiTitle: string | null, webuiDescription: string | null) => void;
|
||||
setTokenRenewal: (renewalTime: number, expiresAt: number) => void; // Track token renewal
|
||||
}
|
||||
|
||||
const useBackendStateStoreBase = create<BackendState>()((set, get) => ({
|
||||
@@ -156,7 +159,7 @@ const useBackendState = createSelectors(useBackendStateStoreBase)
|
||||
|
||||
export { useBackendState }
|
||||
|
||||
const parseTokenPayload = (token: string): { sub?: string; role?: string } => {
|
||||
const parseTokenPayload = (token: string): { sub?: string; role?: string; exp?: number } => {
|
||||
try {
|
||||
// JWT tokens are in the format: header.payload.signature
|
||||
const parts = token.split('.');
|
||||
@@ -179,13 +182,20 @@ const isGuestToken = (token: string): boolean => {
|
||||
return payload.role === 'guest';
|
||||
};
|
||||
|
||||
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null } => {
|
||||
const getTokenExpiresAt = (token: string): number | null => {
|
||||
const payload = parseTokenPayload(token);
|
||||
return payload.exp ? payload.exp * 1000 : null; // Convert to milliseconds
|
||||
};
|
||||
|
||||
const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; coreVersion: string | null; apiVersion: string | null; username: string | null; webuiTitle: string | null; webuiDescription: string | null; lastTokenRenewal: number | null; tokenExpiresAt: number | null } => {
|
||||
const token = localStorage.getItem('LIGHTRAG-API-TOKEN');
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
const webuiTitle = localStorage.getItem('LIGHTRAG-WEBUI-TITLE');
|
||||
const webuiDescription = localStorage.getItem('LIGHTRAG-WEBUI-DESCRIPTION');
|
||||
const lastTokenRenewal = localStorage.getItem('LIGHTRAG-LAST-TOKEN-RENEWAL');
|
||||
const username = token ? getUsernameFromToken(token) : null;
|
||||
const tokenExpiresAt = token ? getTokenExpiresAt(token) : null;
|
||||
|
||||
if (!token) {
|
||||
return {
|
||||
@@ -196,6 +206,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
||||
username: null,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
lastTokenRenewal: null,
|
||||
tokenExpiresAt: null,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -207,6 +219,8 @@ const initAuthState = (): { isAuthenticated: boolean; isGuestMode: boolean; core
|
||||
username: username,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
lastTokenRenewal: lastTokenRenewal ? parseInt(lastTokenRenewal) : null,
|
||||
tokenExpiresAt: tokenExpiresAt,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -222,6 +236,8 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
username: initialState.username,
|
||||
webuiTitle: initialState.webuiTitle,
|
||||
webuiDescription: initialState.webuiDescription,
|
||||
lastTokenRenewal: initialState.lastTokenRenewal,
|
||||
tokenExpiresAt: initialState.tokenExpiresAt,
|
||||
|
||||
login: (token, isGuest = false, coreVersion = null, apiVersion = null, webuiTitle = null, webuiDescription = null) => {
|
||||
localStorage.setItem('LIGHTRAG-API-TOKEN', token);
|
||||
@@ -246,6 +262,7 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
}
|
||||
|
||||
const username = getUsernameFromToken(token);
|
||||
const tokenExpiresAt = getTokenExpiresAt(token);
|
||||
set({
|
||||
isAuthenticated: true,
|
||||
isGuestMode: isGuest,
|
||||
@@ -254,11 +271,13 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
tokenExpiresAt: tokenExpiresAt,
|
||||
});
|
||||
},
|
||||
|
||||
logout: () => {
|
||||
localStorage.removeItem('LIGHTRAG-API-TOKEN');
|
||||
localStorage.removeItem('LIGHTRAG-LAST-TOKEN-RENEWAL');
|
||||
|
||||
const coreVersion = localStorage.getItem('LIGHTRAG-CORE-VERSION');
|
||||
const apiVersion = localStorage.getItem('LIGHTRAG-API-VERSION');
|
||||
@@ -273,6 +292,8 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
apiVersion: apiVersion,
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription,
|
||||
lastTokenRenewal: null,
|
||||
tokenExpiresAt: null,
|
||||
});
|
||||
},
|
||||
|
||||
@@ -311,6 +332,17 @@ export const useAuthStore = create<AuthState>(set => {
|
||||
webuiTitle: webuiTitle,
|
||||
webuiDescription: webuiDescription
|
||||
});
|
||||
},
|
||||
|
||||
setTokenRenewal: (renewalTime, expiresAt) => {
|
||||
// Update localStorage
|
||||
localStorage.setItem('LIGHTRAG-LAST-TOKEN-RENEWAL', renewalTime.toString());
|
||||
|
||||
// Update state
|
||||
set({
|
||||
lastTokenRenewal: renewalTime,
|
||||
tokenExpiresAt: expiresAt
|
||||
});
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
408
tests/test_token_auto_renewal.py
Normal file
408
tests/test_token_auto_renewal.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Pytest unit tests for token auto-renewal functionality
|
||||
|
||||
Tests:
|
||||
1. Backend token renewal logic
|
||||
2. Rate limiting for token renewals
|
||||
3. Token renewal state tracking
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import Mock
|
||||
from fastapi import Response
|
||||
import time
|
||||
import sys
|
||||
|
||||
# Mock the config before importing utils_api
|
||||
sys.modules["lightrag.api.config"] = Mock()
|
||||
sys.modules["lightrag.api.auth"] = Mock()
|
||||
|
||||
# Create a simple token renewal cache for testing
|
||||
_token_renewal_cache = {}
|
||||
_RENEWAL_MIN_INTERVAL = 60
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestTokenRenewal:
|
||||
"""Tests for token auto-renewal logic"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_handler(self):
|
||||
"""Mock authentication handler"""
|
||||
handler = Mock()
|
||||
handler.guest_expire_hours = 24
|
||||
handler.expire_hours = 24
|
||||
handler.create_token = Mock(return_value="new-token-12345")
|
||||
return handler
|
||||
|
||||
@pytest.fixture
|
||||
def mock_global_args(self):
|
||||
"""Mock global configuration"""
|
||||
args = Mock()
|
||||
args.token_auto_renew = True
|
||||
args.token_renew_threshold = 0.5
|
||||
return args
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_info_guest(self):
|
||||
"""Mock token info for guest user"""
|
||||
# Token with 10 hours remaining (below 50% of 24 hours)
|
||||
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
|
||||
return {
|
||||
"username": "guest",
|
||||
"role": "guest",
|
||||
"exp": exp_time,
|
||||
"metadata": {"auth_mode": "disabled"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_info_user(self):
|
||||
"""Mock token info for regular user"""
|
||||
# Token with 10 hours remaining (below 50% of 24 hours)
|
||||
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
|
||||
return {
|
||||
"username": "testuser",
|
||||
"role": "user",
|
||||
"exp": exp_time,
|
||||
"metadata": {"auth_mode": "enabled"},
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_info_above_threshold(self):
|
||||
"""Mock token info with time above renewal threshold"""
|
||||
# Token with 20 hours remaining (above 50% of 24 hours)
|
||||
exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
|
||||
return {
|
||||
"username": "testuser",
|
||||
"role": "user",
|
||||
"exp": exp_time,
|
||||
"metadata": {"auth_mode": "enabled"},
|
||||
}
|
||||
|
||||
def test_token_renewal_when_below_threshold(
|
||||
self, mock_auth_handler, mock_global_args, mock_token_info_user
|
||||
):
|
||||
"""Test that token is renewed when remaining time < threshold"""
|
||||
# Use global cache
|
||||
global _token_renewal_cache
|
||||
|
||||
# Clear cache
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
|
||||
# Simulate the renewal logic
|
||||
expire_time = mock_token_info_user["exp"]
|
||||
now = datetime.now(timezone.utc)
|
||||
remaining_seconds = (expire_time - now).total_seconds()
|
||||
|
||||
role = mock_token_info_user["role"]
|
||||
total_hours = (
|
||||
mock_auth_handler.expire_hours
|
||||
if role == "user"
|
||||
else mock_auth_handler.guest_expire_hours
|
||||
)
|
||||
total_seconds = total_hours * 3600
|
||||
|
||||
# Should renew because remaining_seconds < total_seconds * 0.5
|
||||
should_renew = (
|
||||
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
|
||||
)
|
||||
assert should_renew is True
|
||||
|
||||
# Simulate renewal
|
||||
username = mock_token_info_user["username"]
|
||||
current_time = time.time()
|
||||
last_renewal = _token_renewal_cache.get(username, 0)
|
||||
time_since_last_renewal = current_time - last_renewal
|
||||
|
||||
# Should pass rate limit (first renewal)
|
||||
assert time_since_last_renewal >= 60 or last_renewal == 0
|
||||
|
||||
# Perform renewal
|
||||
new_token = mock_auth_handler.create_token(
|
||||
username=username, role=role, metadata=mock_token_info_user["metadata"]
|
||||
)
|
||||
response.headers["X-New-Token"] = new_token
|
||||
_token_renewal_cache[username] = current_time
|
||||
|
||||
# Verify
|
||||
assert "X-New-Token" in response.headers
|
||||
assert response.headers["X-New-Token"] == "new-token-12345"
|
||||
assert username in _token_renewal_cache
|
||||
|
||||
def test_token_no_renewal_when_above_threshold(
|
||||
self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
|
||||
):
|
||||
"""Test that token is NOT renewed when remaining time > threshold"""
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
|
||||
expire_time = mock_token_info_above_threshold["exp"]
|
||||
now = datetime.now(timezone.utc)
|
||||
remaining_seconds = (expire_time - now).total_seconds()
|
||||
|
||||
mock_token_info_above_threshold["role"]
|
||||
total_hours = mock_auth_handler.expire_hours
|
||||
total_seconds = total_hours * 3600
|
||||
|
||||
# Should NOT renew because remaining_seconds > total_seconds * 0.5
|
||||
should_renew = (
|
||||
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
|
||||
)
|
||||
assert should_renew is False
|
||||
|
||||
# No renewal should happen
|
||||
assert "X-New-Token" not in response.headers
|
||||
|
||||
def test_token_renewal_disabled(
|
||||
self, mock_auth_handler, mock_global_args, mock_token_info_user
|
||||
):
|
||||
"""Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
|
||||
mock_global_args.token_auto_renew = False
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
|
||||
# Auto-renewal is disabled, so even if below threshold, no renewal
|
||||
if not mock_global_args.token_auto_renew:
|
||||
# Skip renewal logic
|
||||
pass
|
||||
|
||||
assert "X-New-Token" not in response.headers
|
||||
|
||||
def test_token_renewal_for_guest_mode(
|
||||
self, mock_auth_handler, mock_global_args, mock_token_info_guest
|
||||
):
|
||||
"""Test that guest tokens are renewed correctly"""
|
||||
# Use global cache
|
||||
global _token_renewal_cache
|
||||
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
|
||||
expire_time = mock_token_info_guest["exp"]
|
||||
now = datetime.now(timezone.utc)
|
||||
remaining_seconds = (expire_time - now).total_seconds()
|
||||
|
||||
role = mock_token_info_guest["role"]
|
||||
total_hours = mock_auth_handler.guest_expire_hours
|
||||
total_seconds = total_hours * 3600
|
||||
|
||||
should_renew = (
|
||||
remaining_seconds < total_seconds * mock_global_args.token_renew_threshold
|
||||
)
|
||||
assert should_renew is True
|
||||
|
||||
# Renewal for guest
|
||||
username = mock_token_info_guest["username"]
|
||||
new_token = mock_auth_handler.create_token(
|
||||
username=username, role=role, metadata=mock_token_info_guest["metadata"]
|
||||
)
|
||||
response.headers["X-New-Token"] = new_token
|
||||
_token_renewal_cache[username] = time.time()
|
||||
|
||||
assert "X-New-Token" in response.headers
|
||||
assert username in _token_renewal_cache
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestRateLimiting:
|
||||
"""Tests for token renewal rate limiting"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_handler(self):
|
||||
"""Mock authentication handler"""
|
||||
handler = Mock()
|
||||
handler.expire_hours = 24
|
||||
handler.create_token = Mock(return_value="new-token-12345")
|
||||
return handler
|
||||
|
||||
def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
|
||||
"""Test that second renewal within 60s is blocked"""
|
||||
# Use global cache and constant
|
||||
global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
|
||||
|
||||
username = "testuser"
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
# First renewal
|
||||
current_time_1 = time.time()
|
||||
_token_renewal_cache[username] = current_time_1
|
||||
|
||||
response_1 = Mock(spec=Response)
|
||||
response_1.headers = {}
|
||||
response_1.headers["X-New-Token"] = "new-token-12345"
|
||||
|
||||
# Immediate second renewal attempt (within 60s)
|
||||
current_time_2 = time.time() # Almost same time
|
||||
last_renewal = _token_renewal_cache.get(username, 0)
|
||||
time_since_last_renewal = current_time_2 - last_renewal
|
||||
|
||||
# Should be blocked by rate limit
|
||||
assert time_since_last_renewal < _RENEWAL_MIN_INTERVAL
|
||||
|
||||
response_2 = Mock(spec=Response)
|
||||
response_2.headers = {}
|
||||
|
||||
# No new token should be issued
|
||||
if time_since_last_renewal < _RENEWAL_MIN_INTERVAL:
|
||||
# Rate limited, skip renewal
|
||||
pass
|
||||
|
||||
assert "X-New-Token" not in response_2.headers
|
||||
|
||||
def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
|
||||
"""Test that renewal succeeds after 60s interval"""
|
||||
# Use global cache and constant
|
||||
global _token_renewal_cache, _RENEWAL_MIN_INTERVAL
|
||||
|
||||
username = "testuser"
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
# First renewal at time T
|
||||
first_renewal_time = time.time() - 61 # 61 seconds ago
|
||||
_token_renewal_cache[username] = first_renewal_time
|
||||
|
||||
# Second renewal attempt now
|
||||
current_time = time.time()
|
||||
last_renewal = _token_renewal_cache.get(username, 0)
|
||||
time_since_last_renewal = current_time - last_renewal
|
||||
|
||||
# Should pass rate limit (>60s elapsed)
|
||||
assert time_since_last_renewal >= _RENEWAL_MIN_INTERVAL
|
||||
|
||||
response = Mock(spec=Response)
|
||||
response.headers = {}
|
||||
|
||||
if time_since_last_renewal >= _RENEWAL_MIN_INTERVAL:
|
||||
new_token = mock_auth_handler.create_token(
|
||||
username=username, role="user", metadata={}
|
||||
)
|
||||
response.headers["X-New-Token"] = new_token
|
||||
_token_renewal_cache[username] = current_time
|
||||
|
||||
assert "X-New-Token" in response.headers
|
||||
assert response.headers["X-New-Token"] == "new-token-12345"
|
||||
|
||||
def test_rate_limit_per_user(self, mock_auth_handler):
|
||||
"""Test that different users have independent rate limits"""
|
||||
# Use global cache
|
||||
global _token_renewal_cache
|
||||
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
user1 = "user1"
|
||||
user2 = "user2"
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# User1 gets renewal
|
||||
_token_renewal_cache[user1] = current_time
|
||||
|
||||
# User2 should still be able to get renewal (independent cache)
|
||||
last_renewal_user2 = _token_renewal_cache.get(user2, 0)
|
||||
assert last_renewal_user2 == 0 # No previous renewal
|
||||
|
||||
# User2 can renew
|
||||
_token_renewal_cache[user2] = current_time
|
||||
|
||||
# Both users should have entries
|
||||
assert user1 in _token_renewal_cache
|
||||
assert user2 in _token_renewal_cache
|
||||
assert _token_renewal_cache[user1] == _token_renewal_cache[user2]
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
class TestTokenExpirationCalculation:
|
||||
"""Tests for token expiration time calculation"""
|
||||
|
||||
def test_expiration_extraction_from_jwt(self):
|
||||
"""Test extracting expiration time from JWT token"""
|
||||
import base64
|
||||
import json
|
||||
|
||||
# Create a mock JWT payload
|
||||
exp_timestamp = int(
|
||||
(datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
|
||||
)
|
||||
payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}
|
||||
|
||||
# Encode as base64 (simulating JWT structure: header.payload.signature)
|
||||
payload_b64 = base64.b64encode(json.dumps(payload).encode()).decode()
|
||||
mock_token = f"header.{payload_b64}.signature"
|
||||
|
||||
# Simulate extraction
|
||||
parts = mock_token.split(".")
|
||||
assert len(parts) == 3
|
||||
|
||||
decoded_payload = json.loads(base64.b64decode(parts[1]))
|
||||
assert decoded_payload["exp"] == exp_timestamp
|
||||
assert decoded_payload["sub"] == "testuser"
|
||||
|
||||
def test_remaining_time_calculation(self):
|
||||
"""Test calculation of remaining token time"""
|
||||
# Token expires in 10 hours
|
||||
exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
remaining_seconds = (exp_time - now).total_seconds()
|
||||
|
||||
# Should be approximately 10 hours (36000 seconds)
|
||||
assert 35990 < remaining_seconds < 36010
|
||||
|
||||
# Calculate percentage remaining (for 24-hour token)
|
||||
total_seconds = 24 * 3600
|
||||
percentage_remaining = remaining_seconds / total_seconds
|
||||
|
||||
# Should be approximately 41.67% remaining
|
||||
assert 0.41 < percentage_remaining < 0.42
|
||||
|
||||
def test_threshold_comparison(self):
|
||||
"""Test threshold-based renewal decision"""
|
||||
threshold = 0.5
|
||||
total_hours = 24
|
||||
total_seconds = total_hours * 3600
|
||||
|
||||
# Scenario 1: 10 hours remaining -> should renew
|
||||
remaining_seconds_1 = 10 * 3600
|
||||
should_renew_1 = remaining_seconds_1 < total_seconds * threshold
|
||||
assert should_renew_1 is True
|
||||
|
||||
# Scenario 2: 20 hours remaining -> should NOT renew
|
||||
remaining_seconds_2 = 20 * 3600
|
||||
should_renew_2 = remaining_seconds_2 < total_seconds * threshold
|
||||
assert should_renew_2 is False
|
||||
|
||||
# Scenario 3: Exactly 12 hours remaining (at threshold) -> should NOT renew
|
||||
remaining_seconds_3 = 12 * 3600
|
||||
should_renew_3 = remaining_seconds_3 < total_seconds * threshold
|
||||
assert should_renew_3 is False
|
||||
|
||||
|
||||
@pytest.mark.offline
|
||||
def test_renewal_cache_cleanup():
|
||||
"""Test that renewal cache can be cleared"""
|
||||
# Use global cache
|
||||
global _token_renewal_cache
|
||||
|
||||
# Clear first
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
# Add some entries
|
||||
_token_renewal_cache["user1"] = time.time()
|
||||
_token_renewal_cache["user2"] = time.time()
|
||||
|
||||
assert len(_token_renewal_cache) == 2
|
||||
|
||||
# Clear cache
|
||||
_token_renewal_cache.clear()
|
||||
|
||||
assert len(_token_renewal_cache) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user