Move shared lock validation to factory functions and fix test formatting

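- Enforce the shared-data initialization check in the lock factory functions (get_internal_lock, get_data_init_lock), raising RuntimeError when initialize_share_data() has not been called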
- Simplify UnifiedLock class logic
- Update lock safety tests
- Fix line wrapping in test files
This commit is contained in:
yangdx
2025-12-19 15:58:02 +08:00
parent a3b33bbc3c
commit e9003f3f13
4 changed files with 114 additions and 151 deletions

View File

@@ -163,30 +163,19 @@ class UnifiedLock(Generic[T]):
enable_output=self._enable_logging,
)
# Then acquire the main lock
if self._lock is not None:
if self._is_async:
await self._lock.acquire()
else:
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
# Acquire the main lock
# Note: self._lock should never be None here as the check has been moved
# to get_internal_lock() and get_data_init_lock() functions
if self._is_async:
await self._lock.acquire()
else:
# CRITICAL: Raise exception instead of allowing unprotected execution
error_msg = (
f"CRITICAL: Lock '{self._name}' is None - shared data not initialized. "
f"Call initialize_share_data() before using locks!"
)
direct_log(
f"== Lock == Process {self._pid}: {error_msg}",
level="ERROR",
enable_output=True,
)
raise RuntimeError(error_msg)
self._lock.acquire()
direct_log(
f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
level="INFO",
enable_output=self._enable_logging,
)
return self
except Exception as e:
# If main lock acquisition fails, release the async lock if it was acquired
@@ -272,6 +261,10 @@ class UnifiedLock(Generic[T]):
try:
if self._is_async:
raise RuntimeError("Use 'async with' for shared_storage lock")
# Acquire the main lock
# Note: self._lock should never be None here as the check has been moved
# to get_internal_lock() and get_data_init_lock() functions
direct_log(
f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)",
level="DEBUG",
@@ -1077,6 +1070,10 @@ class _KeyedLockContext:
def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified storage lock for data consistency"""
if _internal_lock is None:
raise RuntimeError(
"Shared data not initialized. Call initialize_share_data() before using locks!"
)
async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_internal_lock,
@@ -1107,6 +1104,10 @@ def get_storage_keyed_lock(
def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
"""return unified data initialization lock for ensuring atomic data initialization"""
if _data_init_lock is None:
raise RuntimeError(
"Shared data not initialized. Call initialize_share_data() before using locks!"
)
async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
return UnifiedLock(
lock=_data_init_lock,

View File

@@ -316,7 +316,9 @@ class TestPostgresDimensionMismatch:
# Return different counts based on table name in query and migration state
if "LIGHTRAG_DOC_CHUNKS_model_1536d" in query:
# After migration: return migrated count, before: return 0
return {"count": len(mock_records) if migration_done["value"] else 0}
return {
"count": len(mock_records) if migration_done["value"] else 0
}
# Legacy table always has 2 records (matching mock_records)
return {"count": len(mock_records)}
elif "pg_attribute" in query:
@@ -360,16 +362,21 @@ class TestPostgresDimensionMismatch:
# Custom mock for _pg_migrate_workspace_data that updates migration_done
async def mock_migrate_func(*args, **kwargs):
migration_done["value"] = True # Set BEFORE returning so verification query sees it
migration_done["value"] = (
True # Set BEFORE returning so verification query sees it
)
return len(mock_records)
with patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=mock_table_exists,
), patch(
"lightrag.kg.postgres_impl._pg_migrate_workspace_data",
side_effect=mock_migrate_func,
) as mock_migrate:
with (
patch(
"lightrag.kg.postgres_impl._pg_table_exists",
side_effect=mock_table_exists,
),
patch(
"lightrag.kg.postgres_impl._pg_migrate_workspace_data",
side_effect=mock_migrate_func,
) as mock_migrate,
):
# Call setup_table with matching 1536d
await PGVectorStorage.setup_table(
db,

View File

@@ -1,13 +1,16 @@
"""
Tests for UnifiedLock safety when lock is None.
This test module verifies that UnifiedLock raises RuntimeError instead of
allowing unprotected execution when the underlying lock is None, preventing
false security and potential race conditions.
This test module verifies that get_internal_lock() and get_data_init_lock()
raise RuntimeError when shared data is not initialized, preventing false
security and potential race conditions.
Critical Bug 1: When self._lock is None, __aenter__ used to log WARNING but
still return successfully, allowing critical sections to run without lock
protection, causing race conditions and data corruption.
Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
to the lock factory functions (get_internal_lock, get_data_init_lock) for
early failure detection.
Critical Bug 1 (Fixed): When self._lock is None, the code would fail with
AttributeError. Now the check is in factory functions for clearer errors.
Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
recovery logic would attempt to release it again, causing double-release issues.
@@ -15,81 +18,52 @@ recovery logic would attempt to release it again, causing double-release issues.
import pytest
from unittest.mock import MagicMock, AsyncMock
from lightrag.kg.shared_storage import UnifiedLock
from lightrag.kg.shared_storage import (
UnifiedLock,
get_internal_lock,
get_data_init_lock,
finalize_share_data,
)
class TestUnifiedLockSafety:
"""Test suite for UnifiedLock None safety checks."""
@pytest.mark.asyncio
async def test_unified_lock_raises_on_none_async(self):
def setup_method(self):
"""Ensure shared data is finalized before each test."""
finalize_share_data()
def teardown_method(self):
"""Clean up after each test."""
finalize_share_data()
def test_get_internal_lock_raises_when_not_initialized(self):
"""
Test that UnifiedLock raises RuntimeError when lock is None (async mode).
Test that get_internal_lock() raises RuntimeError when shared data is not initialized.
Scenario: Attempt to use UnifiedLock before initialize_share_data() is called.
Expected: RuntimeError raised, preventing unprotected critical section execution.
"""
lock = UnifiedLock(
lock=None, is_async=True, name="test_async_lock", enable_logging=False
)
with pytest.raises(
RuntimeError, match="shared data not initialized|Lock.*is None"
):
async with lock:
# This code should NEVER execute
pytest.fail(
"Code inside lock context should not execute when lock is None"
)
@pytest.mark.asyncio
async def test_unified_lock_raises_on_none_sync(self):
"""
Test that UnifiedLock raises RuntimeError when lock is None (sync mode).
Scenario: Attempt to use UnifiedLock with None lock in sync mode.
Scenario: Call get_internal_lock() before initialize_share_data() is called.
Expected: RuntimeError raised with clear error message.
"""
lock = UnifiedLock(
lock=None, is_async=False, name="test_sync_lock", enable_logging=False
)
This test verifies the None check has been moved to the factory function.
"""
with pytest.raises(
RuntimeError, match="shared data not initialized|Lock.*is None"
RuntimeError, match="Shared data not initialized.*initialize_share_data"
):
async with lock:
# This code should NEVER execute
pytest.fail(
"Code inside lock context should not execute when lock is None"
)
get_internal_lock()
@pytest.mark.asyncio
async def test_error_message_clarity(self):
def test_get_data_init_lock_raises_when_not_initialized(self):
"""
Test that the error message clearly indicates the problem and solution.
Test that get_data_init_lock() raises RuntimeError when shared data is not initialized.
Scenario: Lock is None and user tries to acquire it.
Expected: Error message mentions 'shared data not initialized' and
'initialize_share_data()'.
Scenario: Call get_data_init_lock() before initialize_share_data() is called.
Expected: RuntimeError raised with clear error message.
This test verifies the None check has been moved to the factory function.
"""
lock = UnifiedLock(
lock=None,
is_async=True,
name="test_error_message",
enable_logging=False,
)
with pytest.raises(RuntimeError) as exc_info:
async with lock:
pass
error_message = str(exc_info.value)
# Verify error message contains helpful information
assert (
"shared data not initialized" in error_message.lower()
or "lock" in error_message.lower()
)
assert "initialize_share_data" in error_message or "None" in error_message
with pytest.raises(
RuntimeError, match="Shared data not initialized.*initialize_share_data"
):
get_data_init_lock()
@pytest.mark.asyncio
async def test_aexit_no_double_release_on_async_lock_failure(self):
@@ -144,48 +118,3 @@ class TestUnifiedLockSafety:
# Main lock should have been released successfully
main_lock.release.assert_called_once()
@pytest.mark.asyncio
async def test_aexit_recovery_on_main_lock_failure(self):
"""
Test that __aexit__ recovery logic works when main lock release fails.
Scenario: main_lock.release() fails before async_lock is attempted.
Expected: Recovery logic should attempt to release async_lock to prevent
resource leaks.
This verifies the recovery logic still works correctly with async_lock_released tracking.
"""
# Create mock locks
main_lock = MagicMock()
main_lock.acquire = MagicMock()
# Make main_lock.release() fail
def mock_main_release_fail():
raise RuntimeError("Main lock release failed")
main_lock.release = MagicMock(side_effect=mock_main_release_fail)
async_lock = AsyncMock()
async_lock.acquire = AsyncMock()
async_lock.release = MagicMock()
# Create UnifiedLock with both locks (sync mode with async_lock)
lock = UnifiedLock(
lock=main_lock, is_async=False, name="test_recovery", enable_logging=False
)
lock._async_lock = async_lock
# Try to use the lock - should fail during __aexit__
try:
async with lock:
pass
except RuntimeError as e:
# Should get the main lock release error
assert "Main lock release failed" in str(e)
# Main lock release should have been attempted
main_lock.release.assert_called_once()
# Recovery logic should have attempted to release async_lock
async_lock.release.assert_called_once()

View File

@@ -48,7 +48,11 @@ class TestWorkspaceMigrationIsolation:
sql_lower = sql.lower()
# Count query for new table workspace data (verification before migration)
if "count(*)" in sql_lower and "model_1536d" in sql_lower and "where workspace" in sql_lower:
if (
"count(*)" in sql_lower
and "model_1536d" in sql_lower
and "where workspace" in sql_lower
):
return new_table_record_count # Initially 0
# Count query with workspace filter (legacy table) - for workspace count
@@ -60,11 +64,19 @@ class TestWorkspaceMigrationIsolation:
return {"count": 0}
# Count query for legacy table (total, no workspace filter)
elif "count(*)" in sql_lower and "lightrag" in sql_lower and "where workspace" not in sql_lower:
elif (
"count(*)" in sql_lower
and "lightrag" in sql_lower
and "where workspace" not in sql_lower
):
return {"count": 5} # Total records in legacy
# SELECT with workspace filter for migration
elif "select * from" in sql_lower and "where workspace" in sql_lower and multirows:
elif (
"select * from" in sql_lower
and "where workspace" in sql_lower
and multirows
):
workspace = params[0] if params else None
offset = params[1] if len(params) > 1 else 0
if workspace == "workspace_a" and offset == 0:
@@ -94,7 +106,9 @@ class TestWorkspaceMigrationIsolation:
# Mock _pg_table_exists, _pg_create_table, and _pg_migrate_workspace_data
from unittest.mock import patch
async def mock_migrate_workspace_data(db, legacy, new, workspace, expected_count, dim):
async def mock_migrate_workspace_data(
db, legacy, new, workspace, expected_count, dim
):
# Simulate migration by updating count
new_table_record_count["count"] = expected_count
return expected_count
@@ -123,7 +137,9 @@ class TestWorkspaceMigrationIsolation:
# Verify the migration function was called with the correct workspace
# The mock_migrate_workspace_data tracks that the migration was triggered
# with workspace_a data (2 records)
assert new_table_record_count["count"] == 2, "Should have migrated 2 records from workspace_a"
assert (
new_table_record_count["count"] == 2
), "Should have migrated 2 records from workspace_a"
@pytest.mark.asyncio
async def test_migration_without_workspace_raises_error(self):
@@ -175,7 +191,11 @@ class TestWorkspaceMigrationIsolation:
sql_lower = sql.lower()
# Count query for new table workspace data (should be 0 initially)
if "count(*)" in sql_lower and "model_1536d" in sql_lower and "where workspace" in sql_lower:
if (
"count(*)" in sql_lower
and "model_1536d" in sql_lower
and "where workspace" in sql_lower
):
return new_table_count
# Count query with workspace filter (legacy table)
@@ -184,7 +204,11 @@ class TestWorkspaceMigrationIsolation:
return {"count": 1} # 1 record for the queried workspace
# Count query for legacy table total (no workspace filter)
elif "count(*)" in sql_lower and "lightrag" in sql_lower and "where workspace" not in sql_lower:
elif (
"count(*)" in sql_lower
and "lightrag" in sql_lower
and "where workspace" not in sql_lower
):
return {"count": 3} # 3 total records in legacy
return {}
@@ -194,7 +218,9 @@ class TestWorkspaceMigrationIsolation:
from unittest.mock import patch
async def mock_migrate_workspace_data(db, legacy, new, workspace, expected_count, dim):
async def mock_migrate_workspace_data(
db, legacy, new, workspace, expected_count, dim
):
# Simulate migration by updating count
new_table_count["count"] = expected_count
return expected_count