Files
lightrag/tests/test_token_auto_renewal.py
yangdx 01aaded80c Implement token auto-renewal and sliding window expiration mechanism
* Add backend token renewal logic
* Handle X-New-Token in frontend
* Add rate limiting and config options
* Implement silent refresh for guests
* Add unit tests for renewal logic
2025-12-26 11:31:48 +08:00

409 lines
13 KiB
Python

"""
Pytest unit tests for token auto-renewal functionality
Tests:
1. Backend token renewal logic
2. Rate limiting for token renewals
3. Token renewal state tracking
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock
from fastapi import Response
import time
import sys
# Stub the lightrag API config/auth modules so importing API code does not
# require a real configuration. NOTE(review): nothing in this file imports
# utils_api, so these stubs may be unnecessary — confirm before removing.
# ``setdefault`` only installs a stub when the module is not already in
# ``sys.modules``. A real module that another test loaded earlier in the same
# session is therefore left intact instead of being replaced by a Mock.
sys.modules.setdefault("lightrag.api.config", Mock())
sys.modules.setdefault("lightrag.api.auth", Mock())

# Renewal cache shared by every test in this file:
# username -> time.time() timestamp of that user's most recent renewal.
_token_renewal_cache = {}
# Minimum number of seconds between two renewals for the same user.
_RENEWAL_MIN_INTERVAL = 60
@pytest.mark.offline
class TestTokenRenewal:
    """Tests for token auto-renewal logic.

    The renewal algorithm is reproduced locally in ``_should_renew`` and
    ``_apply_renewal`` rather than imported from the API package. Every test
    drives that single code path instead of re-deriving the decision inline.

    The algorithm issues a new token only when all of these hold:
    - auto-renewal is enabled;
    - the remaining lifetime is below ``token_renew_threshold`` of the role's
      configured lifetime;
    - the user is not rate-limited.
    """

    @pytest.fixture(autouse=True)
    def _isolate_renewal_cache(self):
        """Start and finish every test with an empty shared renewal cache."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.guest_expire_hours = 24
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @pytest.fixture
    def mock_global_args(self):
        """Mock global configuration"""
        args = Mock()
        args.token_auto_renew = True
        args.token_renew_threshold = 0.5
        return args

    @pytest.fixture
    def mock_token_info_guest(self):
        """Mock token info for guest user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "guest",
            "role": "guest",
            "exp": exp_time,
            "metadata": {"auth_mode": "disabled"},
        }

    @pytest.fixture
    def mock_token_info_user(self):
        """Mock token info for regular user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    @pytest.fixture
    def mock_token_info_above_threshold(self):
        """Mock token info with time above renewal threshold"""
        # Token with 20 hours remaining (above 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    @staticmethod
    def _make_response():
        """Return a Response mock with a plain dict for ``headers``."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _should_renew(token_info, auth_handler, global_args):
        """Return True if *token_info* is due for renewal.

        Renewal requires ``global_args.token_auto_renew`` to be truthy. It also
        requires the remaining lifetime to be strictly below
        ``token_renew_threshold`` of the configured lifetime. That lifetime is
        ``expire_hours`` for role ``"user"`` and ``guest_expire_hours`` for any
        other role.
        """
        if not global_args.token_auto_renew:
            return False
        remaining_seconds = (
            token_info["exp"] - datetime.now(timezone.utc)
        ).total_seconds()
        total_hours = (
            auth_handler.expire_hours
            if token_info["role"] == "user"
            else auth_handler.guest_expire_hours
        )
        total_seconds = total_hours * 3600
        return remaining_seconds < total_seconds * global_args.token_renew_threshold

    @classmethod
    def _apply_renewal(cls, token_info, auth_handler, global_args, response):
        """Put a new token in ``response.headers["X-New-Token"]`` if due.

        Returns True if a token was issued. A user whose last renewal happened
        less than ``_RENEWAL_MIN_INTERVAL`` seconds ago is rate-limited. A user
        with no cache entry is never rate-limited.
        """
        if not cls._should_renew(token_info, auth_handler, global_args):
            return False
        username = token_info["username"]
        current_time = time.time()
        last_renewal = _token_renewal_cache.get(username, 0)
        if last_renewal and current_time - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False
        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username,
            role=token_info["role"],
            metadata=token_info["metadata"],
        )
        _token_renewal_cache[username] = current_time
        return True

    def test_token_renewal_when_below_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that token is renewed when remaining time < threshold"""
        response = self._make_response()

        assert self._should_renew(
            mock_token_info_user, mock_auth_handler, mock_global_args
        )
        renewed = self._apply_renewal(
            mock_token_info_user, mock_auth_handler, mock_global_args, response
        )

        assert renewed is True
        assert response.headers["X-New-Token"] == "new-token-12345"
        mock_auth_handler.create_token.assert_called_once_with(
            username="testuser", role="user", metadata={"auth_mode": "enabled"}
        )
        assert "testuser" in _token_renewal_cache

    def test_token_no_renewal_when_above_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
    ):
        """Test that token is NOT renewed when remaining time > threshold"""
        response = self._make_response()

        assert not self._should_renew(
            mock_token_info_above_threshold, mock_auth_handler, mock_global_args
        )
        renewed = self._apply_renewal(
            mock_token_info_above_threshold,
            mock_auth_handler,
            mock_global_args,
            response,
        )

        assert renewed is False
        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()
        assert "testuser" not in _token_renewal_cache

    def test_token_renewal_disabled(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
        mock_global_args.token_auto_renew = False
        response = self._make_response()

        # The token is below the renewal threshold, so the disabled flag is
        # the only thing that can prevent renewal here.
        renewed = self._apply_renewal(
            mock_token_info_user, mock_auth_handler, mock_global_args, response
        )

        assert renewed is False
        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()

    def test_token_renewal_for_guest_mode(
        self, mock_auth_handler, mock_global_args, mock_token_info_guest
    ):
        """Test that guest tokens are renewed correctly"""
        response = self._make_response()

        renewed = self._apply_renewal(
            mock_token_info_guest, mock_auth_handler, mock_global_args, response
        )

        assert renewed is True
        assert response.headers["X-New-Token"] == "new-token-12345"
        mock_auth_handler.create_token.assert_called_once_with(
            username="guest", role="guest", metadata={"auth_mode": "disabled"}
        )
        assert "guest" in _token_renewal_cache
@pytest.mark.offline
class TestRateLimiting:
    """Tests for per-user token renewal rate limiting.

    ``_try_renew`` reproduces the rate-limit gate locally. Tests pass explicit
    timestamps, so their outcomes do not depend on wall-clock timing.
    """

    @pytest.fixture(autouse=True)
    def _isolate_renewal_cache(self):
        """Start and finish every test with an empty shared renewal cache."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @staticmethod
    def _make_response():
        """Return a Response mock with a plain dict for ``headers``."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _try_renew(username, auth_handler, response, current_time):
        """Issue a token for *username* at *current_time* unless rate-limited.

        A renewal is blocked when the user's previous renewal happened less
        than ``_RENEWAL_MIN_INTERVAL`` seconds earlier. Exactly the interval
        is allowed. Returns True if a token was placed in ``X-New-Token``.
        """
        last_renewal = _token_renewal_cache.get(username, 0)
        if last_renewal and current_time - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False
        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username, role="user", metadata={}
        )
        _token_renewal_cache[username] = current_time
        return True

    def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
        """Test that second renewal within 60s is blocked"""
        handler = mock_auth_handler
        username = "testuser"
        now = time.time()
        response_1 = self._make_response()
        response_2 = self._make_response()

        assert self._try_renew(username, handler, response_1, now) is True
        # A second attempt one second later falls inside the minimum interval.
        assert self._try_renew(username, handler, response_2, now + 1) is False

        assert response_1.headers["X-New-Token"] == "new-token-12345"
        assert "X-New-Token" not in response_2.headers
        assert handler.create_token.call_count == 1
        # A blocked attempt must not slide the rate-limit window forward.
        assert _token_renewal_cache[username] == now

    def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
        """Test that renewal succeeds after 60s interval"""
        handler = mock_auth_handler
        username = "testuser"
        now = time.time()
        # The previous renewal happened just over the minimum interval ago.
        _token_renewal_cache[username] = now - (_RENEWAL_MIN_INTERVAL + 1)
        response = self._make_response()

        assert self._try_renew(username, handler, response, now) is True
        assert response.headers["X-New-Token"] == "new-token-12345"
        assert _token_renewal_cache[username] == now

    def test_rate_limit_allows_renewal_at_exact_interval(self, mock_auth_handler):
        """Test that the rate-limit boundary is inclusive (>= interval passes)"""
        handler = mock_auth_handler
        username = "testuser"
        now = time.time()
        _token_renewal_cache[username] = now - _RENEWAL_MIN_INTERVAL
        response = self._make_response()

        assert self._try_renew(username, handler, response, now) is True
        assert "X-New-Token" in response.headers

    def test_rate_limit_per_user(self, mock_auth_handler):
        """Test that different users have independent rate limits"""
        handler = mock_auth_handler
        now = time.time()
        first_user1 = self._make_response()
        second_user1 = self._make_response()
        first_user2 = self._make_response()

        assert self._try_renew("user1", handler, first_user1, now) is True
        # user1 is now rate-limited...
        assert self._try_renew("user1", handler, second_user1, now + 1) is False
        # ...but user2's limit is tracked independently.
        assert self._try_renew("user2", handler, first_user2, now + 1) is True

        assert "X-New-Token" not in second_user1.headers
        assert first_user2.headers["X-New-Token"] == "new-token-12345"
        assert _token_renewal_cache == {"user1": now, "user2": now + 1}
@pytest.mark.offline
class TestTokenExpirationCalculation:
    """Tests for token expiration time calculation"""

    def test_expiration_extraction_from_jwt(self):
        """Test extracting the expiration time from a JWT payload segment.

        Real JWT segments are base64url-encoded with the trailing ``=``
        padding removed. The payload is encoded that way here, and the padding
        is restored before decoding. Standard padded ``b64encode`` output
        would hide decoding bugs that only show up with real tokens.
        """
        import base64
        import json

        exp_timestamp = int(
            (datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
        )
        payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}
        # Encode like a real JWT segment: base64url, padding stripped.
        payload_b64 = (
            base64.urlsafe_b64encode(json.dumps(payload).encode())
            .rstrip(b"=")
            .decode()
        )
        mock_token = f"header.{payload_b64}.signature"

        # Simulate extraction from header.payload.signature
        parts = mock_token.split(".")
        assert len(parts) == 3
        segment = parts[1]
        assert not segment.endswith("=")
        # Restore the stripped padding; base64 input length must be a multiple of 4.
        decoded_payload = json.loads(
            base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))
        )
        assert decoded_payload["exp"] == exp_timestamp
        assert decoded_payload["sub"] == "testuser"

    def test_remaining_time_calculation(self):
        """Test calculation of remaining token time"""
        # Token expires in 10 hours
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        now = datetime.now(timezone.utc)
        remaining_seconds = (exp_time - now).total_seconds()
        # Should be approximately 10 hours (36000 seconds)
        assert 35990 < remaining_seconds < 36010
        # Calculate percentage remaining (for 24-hour token)
        total_seconds = 24 * 3600
        percentage_remaining = remaining_seconds / total_seconds
        # Should be approximately 41.67% remaining
        assert 0.41 < percentage_remaining < 0.42

    def test_threshold_comparison(self):
        """Test threshold-based renewal decision"""
        threshold = 0.5
        total_hours = 24
        total_seconds = total_hours * 3600
        # Scenario 1: 10 hours remaining -> should renew
        remaining_seconds_1 = 10 * 3600
        should_renew_1 = remaining_seconds_1 < total_seconds * threshold
        assert should_renew_1 is True
        # Scenario 2: 20 hours remaining -> should NOT renew
        remaining_seconds_2 = 20 * 3600
        should_renew_2 = remaining_seconds_2 < total_seconds * threshold
        assert should_renew_2 is False
        # Scenario 3: Exactly 12 hours remaining (at threshold) -> should NOT renew
        remaining_seconds_3 = 12 * 3600
        should_renew_3 = remaining_seconds_3 < total_seconds * threshold
        assert should_renew_3 is False
@pytest.mark.offline
def test_renewal_cache_cleanup():
    """Verify that clearing the shared renewal cache drops every entry."""
    # Begin from a known-empty cache. The dict is only mutated, never
    # rebound, so no ``global`` declaration is required.
    _token_renewal_cache.clear()

    for user in ("user1", "user2"):
        _token_renewal_cache[user] = time.time()
    assert len(_token_renewal_cache) == 2

    _token_renewal_cache.clear()
    assert not _token_renewal_cache
if __name__ == "__main__":
    # Propagate pytest's exit status so a failing run exits non-zero
    # instead of always reporting success to the shell.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))