* Add backend token renewal logic * Handle X-New-Token in frontend * Add rate limiting and config options * Implement silent refresh for guests * Add unit tests for renewal logic
409 lines
13 KiB
Python
409 lines
13 KiB
Python
"""
|
|
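Pytest unit tests for token auto-renewal functionality.

No lightrag module is imported by these tests; the renewal rules are
re-enacted in this file against mocks.

Tests:
1. Backend token renewal logic
2. Rate limiting for token renewals
3. Token renewal state tracking
4. Token expiration and renewal-threshold calculation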
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import Mock
|
|
from fastapi import Response
|
|
import time
|
|
import sys
|
|
|
|
# Stub the API config/auth modules so that code importing them does not need
# the real configuration. No import of utils_api is visible in this module.
# setdefault() is used instead of plain assignment so that a real module that
# has already been imported is not replaced by a Mock for the rest of the test
# session. NOTE(review): when the real module is not loaded yet, the stub
# still stays in sys.modules after this file is collected.
sys.modules.setdefault("lightrag.api.config", Mock())
sys.modules.setdefault("lightrag.api.auth", Mock())

# Create a simple token renewal cache for testing.
# Maps username -> time.time() timestamp of that user's most recent renewal.
_token_renewal_cache = {}
# Minimum number of seconds that must pass between two renewals for one user.
_RENEWAL_MIN_INTERVAL = 60
|
|
|
|
|
|
@pytest.mark.offline
class TestTokenRenewal:
    """Tests for token auto-renewal logic.

    The production renewal code is not called here. The renewal rule
    (feature flag, remaining-time threshold, per-user rate limit, and the
    ``X-New-Token`` response header) is re-enacted by the private helpers
    below against mocks, and every test drives that single implementation.
    """

    @pytest.fixture(autouse=True)
    def _reset_renewal_cache(self):
        """Start each test with an empty renewal cache and leave none behind."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.guest_expire_hours = 24
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @pytest.fixture
    def mock_global_args(self):
        """Mock global configuration"""
        args = Mock()
        args.token_auto_renew = True
        args.token_renew_threshold = 0.5
        return args

    @pytest.fixture
    def mock_token_info_guest(self):
        """Mock token info for guest user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "guest",
            "role": "guest",
            "exp": exp_time,
            "metadata": {"auth_mode": "disabled"},
        }

    @pytest.fixture
    def mock_token_info_user(self):
        """Mock token info for regular user"""
        # Token with 10 hours remaining (below 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=10)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    @pytest.fixture
    def mock_token_info_above_threshold(self):
        """Mock token info with time above renewal threshold"""
        # Token with 20 hours remaining (above 50% of 24 hours)
        exp_time = datetime.now(timezone.utc) + timedelta(hours=20)
        return {
            "username": "testuser",
            "role": "user",
            "exp": exp_time,
            "metadata": {"auth_mode": "enabled"},
        }

    @staticmethod
    def _make_response():
        """Return a Response-shaped mock that carries a real, empty headers dict."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _should_renew(token_info, auth_handler, global_args):
        """Return True when auto-renew is enabled and the token is past the threshold.

        The full token lifetime is ``auth_handler.expire_hours`` for role
        "user" and ``auth_handler.guest_expire_hours`` for any other role.
        Renewal is due when the remaining time is strictly less than
        ``lifetime * global_args.token_renew_threshold``.
        """
        if not global_args.token_auto_renew:
            return False

        remaining_seconds = (
            token_info["exp"] - datetime.now(timezone.utc)
        ).total_seconds()
        total_hours = (
            auth_handler.expire_hours
            if token_info["role"] == "user"
            else auth_handler.guest_expire_hours
        )
        total_seconds = total_hours * 3600
        return remaining_seconds < total_seconds * global_args.token_renew_threshold

    @classmethod
    def _renew_if_due(cls, response, token_info, auth_handler, global_args):
        """Put a fresh token into ``response.headers`` when renewal is due.

        Returns:
            True when a token was issued and the renewal time recorded in the
            cache; False when renewal is disabled, not yet due, or the user
            renewed less than ``_RENEWAL_MIN_INTERVAL`` seconds ago.
        """
        if not cls._should_renew(token_info, auth_handler, global_args):
            return False

        username = token_info["username"]
        current_time = time.time()
        # A user with no cache entry counts as "last renewed at 0", which is
        # always far enough in the past to pass the rate limit.
        last_renewal = _token_renewal_cache.get(username, 0)
        if current_time - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False

        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username,
            role=token_info["role"],
            metadata=token_info["metadata"],
        )
        _token_renewal_cache[username] = current_time
        return True

    def test_token_renewal_when_below_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that token is renewed when remaining time < threshold"""
        response = self._make_response()

        # 10 of 24 hours remain, which is below the 50% threshold.
        assert (
            self._should_renew(
                mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is True
        )

        # First renewal for this user, so the rate limit does not apply.
        assert (
            self._renew_if_due(
                response, mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is True
        )

        assert "X-New-Token" in response.headers
        assert response.headers["X-New-Token"] == "new-token-12345"
        assert mock_token_info_user["username"] in _token_renewal_cache
        mock_auth_handler.create_token.assert_called_once_with(
            username="testuser",
            role="user",
            metadata={"auth_mode": "enabled"},
        )

    def test_token_no_renewal_when_above_threshold(
        self, mock_auth_handler, mock_global_args, mock_token_info_above_threshold
    ):
        """Test that token is NOT renewed when remaining time > threshold"""
        response = self._make_response()

        # 20 of 24 hours remain, which is above the 50% threshold.
        assert (
            self._should_renew(
                mock_token_info_above_threshold, mock_auth_handler, mock_global_args
            )
            is False
        )
        assert (
            self._renew_if_due(
                response,
                mock_token_info_above_threshold,
                mock_auth_handler,
                mock_global_args,
            )
            is False
        )

        # No renewal should happen
        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()
        assert not _token_renewal_cache

    def test_token_renewal_disabled(
        self, mock_auth_handler, mock_global_args, mock_token_info_user
    ):
        """Test that no renewal happens when TOKEN_AUTO_RENEW=false"""
        mock_global_args.token_auto_renew = False
        response = self._make_response()

        # The token is below the threshold, so the disabled flag is the only
        # thing that can be preventing the renewal here.
        assert (
            self._renew_if_due(
                response, mock_token_info_user, mock_auth_handler, mock_global_args
            )
            is False
        )

        assert "X-New-Token" not in response.headers
        mock_auth_handler.create_token.assert_not_called()
        assert not _token_renewal_cache

    def test_token_renewal_for_guest_mode(
        self, mock_auth_handler, mock_global_args, mock_token_info_guest
    ):
        """Test that guest tokens are renewed correctly"""
        response = self._make_response()

        assert (
            self._should_renew(
                mock_token_info_guest, mock_auth_handler, mock_global_args
            )
            is True
        )
        assert (
            self._renew_if_due(
                response, mock_token_info_guest, mock_auth_handler, mock_global_args
            )
            is True
        )

        assert "X-New-Token" in response.headers
        assert response.headers["X-New-Token"] == "new-token-12345"
        assert mock_token_info_guest["username"] in _token_renewal_cache
        mock_auth_handler.create_token.assert_called_once_with(
            username="guest",
            role="guest",
            metadata={"auth_mode": "disabled"},
        )
|
|
|
|
|
|
@pytest.mark.offline
class TestRateLimiting:
    """Tests for token renewal rate limiting.

    The production rate limiter is not called here; ``_try_renew`` re-enacts
    the rule "at most one renewal per user every ``_RENEWAL_MIN_INTERVAL``
    seconds" against the module-level renewal cache.
    """

    @pytest.fixture(autouse=True)
    def _reset_renewal_cache(self):
        """Start each test with an empty renewal cache and leave none behind."""
        _token_renewal_cache.clear()
        yield
        _token_renewal_cache.clear()

    @pytest.fixture
    def mock_auth_handler(self):
        """Mock authentication handler"""
        handler = Mock()
        handler.expire_hours = 24
        handler.create_token = Mock(return_value="new-token-12345")
        return handler

    @staticmethod
    def _make_response():
        """Return a Response-shaped mock that carries a real, empty headers dict."""
        response = Mock(spec=Response)
        response.headers = {}
        return response

    @staticmethod
    def _try_renew(response, username, auth_handler, now=None):
        """Attempt a renewal for ``username`` subject to the rate limit.

        Args:
            response: object whose ``headers`` mapping receives ``X-New-Token``.
            username: key used in the renewal cache.
            auth_handler: object providing ``create_token``.
            now: timestamp to treat as the current time; defaults to
                ``time.time()``. Passing it makes the test independent of
                how fast it runs.

        Returns:
            True when a token was issued, False when the attempt was blocked
            because the previous renewal is less than
            ``_RENEWAL_MIN_INTERVAL`` seconds old.
        """
        current_time = time.time() if now is None else now
        # Users without an entry count as "last renewed at 0".
        last_renewal = _token_renewal_cache.get(username, 0)
        if current_time - last_renewal < _RENEWAL_MIN_INTERVAL:
            return False

        response.headers["X-New-Token"] = auth_handler.create_token(
            username=username, role="user", metadata={}
        )
        _token_renewal_cache[username] = current_time
        return True

    def test_rate_limit_prevents_rapid_renewals(self, mock_auth_handler):
        """Test that second renewal within 60s is blocked"""
        username = "testuser"
        start = time.time()

        # First renewal
        response_1 = self._make_response()
        assert self._try_renew(response_1, username, mock_auth_handler, now=start)
        assert response_1.headers["X-New-Token"] == "new-token-12345"

        # Second attempt one second later (well within the interval)
        response_2 = self._make_response()
        blocked_at = start + 1
        assert blocked_at - _token_renewal_cache[username] < _RENEWAL_MIN_INTERVAL
        assert (
            self._try_renew(response_2, username, mock_auth_handler, now=blocked_at)
            is False
        )

        # No new token should be issued and the recorded time must not move
        assert "X-New-Token" not in response_2.headers
        assert mock_auth_handler.create_token.call_count == 1
        assert _token_renewal_cache[username] == start

    def test_rate_limit_allows_renewal_after_interval(self, mock_auth_handler):
        """Test that renewal succeeds after 60s interval"""
        username = "testuser"

        # Previous renewal happened just over one interval ago
        first_renewal_time = time.time() - (_RENEWAL_MIN_INTERVAL + 1)
        _token_renewal_cache[username] = first_renewal_time

        # Second renewal attempt now
        response = self._make_response()
        assert self._try_renew(response, username, mock_auth_handler) is True

        assert "X-New-Token" in response.headers
        assert response.headers["X-New-Token"] == "new-token-12345"
        assert _token_renewal_cache[username] > first_renewal_time

    def test_rate_limit_per_user(self, mock_auth_handler):
        """Test that different users have independent rate limits"""
        user1 = "user1"
        user2 = "user2"
        current_time = time.time()

        # User1 gets renewal
        response_1 = self._make_response()
        assert self._try_renew(response_1, user1, mock_auth_handler, now=current_time)

        # User2 has no previous renewal, so user1's entry must not block it
        assert _token_renewal_cache.get(user2, 0) == 0
        response_2 = self._make_response()
        assert self._try_renew(response_2, user2, mock_auth_handler, now=current_time)
        assert "X-New-Token" in response_2.headers

        # Both users should have entries
        assert user1 in _token_renewal_cache
        assert user2 in _token_renewal_cache
        assert _token_renewal_cache[user1] == _token_renewal_cache[user2]
|
|
|
|
|
|
@pytest.mark.offline
class TestTokenExpirationCalculation:
    """Tests for token expiration time calculation"""

    def test_expiration_extraction_from_jwt(self):
        """Test extracting expiration time from JWT token"""
        import base64
        import json

        # Create a mock JWT payload
        exp_timestamp = int(
            (datetime.now(timezone.utc) + timedelta(hours=24)).timestamp()
        )
        payload = {"sub": "testuser", "role": "user", "exp": exp_timestamp}

        # JWT segments are base64url-encoded with the "=" padding removed, so
        # the mock token is built the same way (header.payload.signature).
        payload_b64 = (
            base64.urlsafe_b64encode(json.dumps(payload).encode())
            .rstrip(b"=")
            .decode()
        )
        mock_token = f"header.{payload_b64}.signature"

        # Simulate extraction
        parts = mock_token.split(".")
        assert len(parts) == 3

        # Restore the stripped padding before decoding; the decoder requires
        # the input length to be a multiple of four.
        segment = parts[1]
        padded_segment = segment + "=" * (-len(segment) % 4)
        decoded_payload = json.loads(base64.urlsafe_b64decode(padded_segment))

        assert decoded_payload["exp"] == exp_timestamp
        assert decoded_payload["sub"] == "testuser"
        assert decoded_payload == payload

    def test_remaining_time_calculation(self):
        """Test calculation of remaining token time"""
        # Read the clock once so the result does not depend on how long the
        # test takes to run between two separate now() calls.
        now = datetime.now(timezone.utc)

        # Token expires in 10 hours
        exp_time = now + timedelta(hours=10)

        remaining_seconds = (exp_time - now).total_seconds()

        # Exactly 10 hours (36000 seconds)
        assert remaining_seconds == 36000

        # Calculate percentage remaining (for 24-hour token)
        total_seconds = 24 * 3600
        percentage_remaining = remaining_seconds / total_seconds

        # Should be approximately 41.67% remaining
        assert 0.41 < percentage_remaining < 0.42

    def test_threshold_comparison(self):
        """Test threshold-based renewal decision"""
        threshold = 0.5
        total_hours = 24
        total_seconds = total_hours * 3600

        # (hours remaining, expected renewal decision)
        scenarios = [
            (10, True),  # below the threshold -> should renew
            (20, False),  # above the threshold -> should NOT renew
            (12, False),  # exactly at the threshold -> should NOT renew
        ]

        for remaining_hours, expected in scenarios:
            remaining_seconds = remaining_hours * 3600
            should_renew = remaining_seconds < total_seconds * threshold
            assert should_renew is expected, f"{remaining_hours}h remaining"
|
|
|
|
|
|
@pytest.mark.offline
def test_renewal_cache_cleanup():
    """Verify that the renewal cache holds entries and can be emptied again."""
    cache = _token_renewal_cache

    # Begin from a known-empty state
    cache.clear()

    # Record a renewal time for two distinct users
    for name in ("user1", "user2"):
        cache[name] = time.time()

    assert len(cache) == 2

    # Emptying the cache must drop every entry
    cache.clear()

    assert len(cache) == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests when executed directly. pytest.main() returns the
    # exit code of the run; hand it to the interpreter so that a failing run
    # does not exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|