Move shared lock validation to factory functions and fix test formatting
- Enforce init in lock factory functions
- Simplify UnifiedLock class logic
- Update lock safety tests
- Fix line wrapping in test files
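Intended call order after this change, as a minimal sketch based on the factory
functions and error messages in the diff below (initialize_share_data()'s exact
signature is assumed here and is not part of this commit):

    from lightrag.kg.shared_storage import initialize_share_data, get_internal_lock

    async def example():
        # Before initialization, the factory itself now fails fast.
        try:
            get_internal_lock()
        except RuntimeError:
            pass  # "Shared data not initialized. Call initialize_share_data() before using locks!"

        initialize_share_data()  # assumed entry point that creates the shared locks
        async with get_internal_lock():  # acquisition no longer checks for None
            ...  # protected critical section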
@@ -163,30 +163,19 @@ class UnifiedLock(Generic[T]):
                 enable_output=self._enable_logging,
             )
 
-            # Then acquire the main lock
-            if self._lock is not None:
-                if self._is_async:
-                    await self._lock.acquire()
-                else:
-                    self._lock.acquire()
-
-                direct_log(
-                    f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
-                    level="INFO",
-                    enable_output=self._enable_logging,
-                )
+            # Acquire the main lock
+            # Note: self._lock should never be None here as the check has been moved
+            # to get_internal_lock() and get_data_init_lock() functions
+            if self._is_async:
+                await self._lock.acquire()
             else:
-                # CRITICAL: Raise exception instead of allowing unprotected execution
-                error_msg = (
-                    f"CRITICAL: Lock '{self._name}' is None - shared data not initialized. "
-                    f"Call initialize_share_data() before using locks!"
-                )
-                direct_log(
-                    f"== Lock == Process {self._pid}: {error_msg}",
-                    level="ERROR",
-                    enable_output=True,
-                )
-                raise RuntimeError(error_msg)
+                self._lock.acquire()
+
+            direct_log(
+                f"== Lock == Process {self._pid}: Acquired lock {self._name} (async={self._is_async})",
+                level="INFO",
+                enable_output=self._enable_logging,
+            )
             return self
         except Exception as e:
             # If main lock acquisition fails, release the async lock if it was acquired
@@ -272,6 +261,10 @@ class UnifiedLock(Generic[T]):
         try:
             if self._is_async:
                 raise RuntimeError("Use 'async with' for shared_storage lock")
+
+            # Acquire the main lock
+            # Note: self._lock should never be None here as the check has been moved
+            # to get_internal_lock() and get_data_init_lock() functions
             direct_log(
                 f"== Lock == Process {self._pid}: Acquiring lock {self._name} (sync)",
                 level="DEBUG",
@@ -1077,6 +1070,10 @@ class _KeyedLockContext:
 
 def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified storage lock for data consistency"""
+    if _internal_lock is None:
+        raise RuntimeError(
+            "Shared data not initialized. Call initialize_share_data() before using locks!"
+        )
     async_lock = _async_locks.get("internal_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_internal_lock,
@@ -1107,6 +1104,10 @@ def get_storage_keyed_lock(
 
 def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
     """return unified data initialization lock for ensuring atomic data initialization"""
+    if _data_init_lock is None:
+        raise RuntimeError(
+            "Shared data not initialized. Call initialize_share_data() before using locks!"
+        )
     async_lock = _async_locks.get("data_init_lock") if _is_multiprocess else None
     return UnifiedLock(
         lock=_data_init_lock,
@@ -316,7 +316,9 @@ class TestPostgresDimensionMismatch:
                 # Return different counts based on table name in query and migration state
                 if "LIGHTRAG_DOC_CHUNKS_model_1536d" in query:
                     # After migration: return migrated count, before: return 0
-                    return {"count": len(mock_records) if migration_done["value"] else 0}
+                    return {
+                        "count": len(mock_records) if migration_done["value"] else 0
+                    }
                 # Legacy table always has 2 records (matching mock_records)
                 return {"count": len(mock_records)}
             elif "pg_attribute" in query:
@@ -360,16 +362,21 @@ class TestPostgresDimensionMismatch:
 
         # Custom mock for _pg_migrate_workspace_data that updates migration_done
         async def mock_migrate_func(*args, **kwargs):
-            migration_done["value"] = True  # Set BEFORE returning so verification query sees it
+            migration_done["value"] = (
+                True  # Set BEFORE returning so verification query sees it
+            )
             return len(mock_records)
 
-        with patch(
-            "lightrag.kg.postgres_impl._pg_table_exists",
-            side_effect=mock_table_exists,
-        ), patch(
-            "lightrag.kg.postgres_impl._pg_migrate_workspace_data",
-            side_effect=mock_migrate_func,
-        ) as mock_migrate:
+        with (
+            patch(
+                "lightrag.kg.postgres_impl._pg_table_exists",
+                side_effect=mock_table_exists,
+            ),
+            patch(
+                "lightrag.kg.postgres_impl._pg_migrate_workspace_data",
+                side_effect=mock_migrate_func,
+            ) as mock_migrate,
+        ):
             # Call setup_table with matching 1536d
             await PGVectorStorage.setup_table(
                 db,
@@ -1,13 +1,16 @@
 """
 Tests for UnifiedLock safety when lock is None.
 
-This test module verifies that UnifiedLock raises RuntimeError instead of
-allowing unprotected execution when the underlying lock is None, preventing
-false security and potential race conditions.
+This test module verifies that get_internal_lock() and get_data_init_lock()
+raise RuntimeError when shared data is not initialized, preventing false
+security and potential race conditions.
 
-Critical Bug 1: When self._lock is None, __aenter__ used to log WARNING but
-still return successfully, allowing critical sections to run without lock
-protection, causing race conditions and data corruption.
+Design: The None check has been moved from UnifiedLock.__aenter__/__enter__
+to the lock factory functions (get_internal_lock, get_data_init_lock) for
+early failure detection.
+
+Critical Bug 1 (Fixed): When self._lock is None, the code would fail with
+AttributeError. Now the check is in factory functions for clearer errors.
 
 Critical Bug 2: In __aexit__, when async_lock.release() fails, the error
 recovery logic would attempt to release it again, causing double-release issues.
@@ -15,81 +18,52 @@ recovery logic would attempt to release it again, causing double-release issues.
 
 import pytest
 from unittest.mock import MagicMock, AsyncMock
-from lightrag.kg.shared_storage import UnifiedLock
+from lightrag.kg.shared_storage import (
+    UnifiedLock,
+    get_internal_lock,
+    get_data_init_lock,
+    finalize_share_data,
+)
 
 
 class TestUnifiedLockSafety:
     """Test suite for UnifiedLock None safety checks."""
 
-    @pytest.mark.asyncio
-    async def test_unified_lock_raises_on_none_async(self):
+    def setup_method(self):
+        """Ensure shared data is finalized before each test."""
+        finalize_share_data()
+
+    def teardown_method(self):
+        """Clean up after each test."""
+        finalize_share_data()
+
+    def test_get_internal_lock_raises_when_not_initialized(self):
         """
-        Test that UnifiedLock raises RuntimeError when lock is None (async mode).
+        Test that get_internal_lock() raises RuntimeError when shared data is not initialized.
 
-        Scenario: Attempt to use UnifiedLock before initialize_share_data() is called.
-        Expected: RuntimeError raised, preventing unprotected critical section execution.
-        """
-        lock = UnifiedLock(
-            lock=None, is_async=True, name="test_async_lock", enable_logging=False
-        )
-
-        with pytest.raises(
-            RuntimeError, match="shared data not initialized|Lock.*is None"
-        ):
-            async with lock:
-                # This code should NEVER execute
-                pytest.fail(
-                    "Code inside lock context should not execute when lock is None"
-                )
-
-    @pytest.mark.asyncio
-    async def test_unified_lock_raises_on_none_sync(self):
-        """
-        Test that UnifiedLock raises RuntimeError when lock is None (sync mode).
-
-        Scenario: Attempt to use UnifiedLock with None lock in sync mode.
+        Scenario: Call get_internal_lock() before initialize_share_data() is called.
         Expected: RuntimeError raised with clear error message.
-        """
-        lock = UnifiedLock(
-            lock=None, is_async=False, name="test_sync_lock", enable_logging=False
-        )
 
+        This test verifies the None check has been moved to the factory function.
+        """
         with pytest.raises(
-            RuntimeError, match="shared data not initialized|Lock.*is None"
+            RuntimeError, match="Shared data not initialized.*initialize_share_data"
         ):
-            async with lock:
-                # This code should NEVER execute
-                pytest.fail(
-                    "Code inside lock context should not execute when lock is None"
-                )
+            get_internal_lock()
 
-    @pytest.mark.asyncio
-    async def test_error_message_clarity(self):
+    def test_get_data_init_lock_raises_when_not_initialized(self):
         """
-        Test that the error message clearly indicates the problem and solution.
+        Test that get_data_init_lock() raises RuntimeError when shared data is not initialized.
 
-        Scenario: Lock is None and user tries to acquire it.
-        Expected: Error message mentions 'shared data not initialized' and
-        'initialize_share_data()'.
+        Scenario: Call get_data_init_lock() before initialize_share_data() is called.
+        Expected: RuntimeError raised with clear error message.
+
+        This test verifies the None check has been moved to the factory function.
         """
-        lock = UnifiedLock(
-            lock=None,
-            is_async=True,
-            name="test_error_message",
-            enable_logging=False,
-        )
-
-        with pytest.raises(RuntimeError) as exc_info:
-            async with lock:
-                pass
-
-        error_message = str(exc_info.value)
-        # Verify error message contains helpful information
-        assert (
-            "shared data not initialized" in error_message.lower()
-            or "lock" in error_message.lower()
-        )
-        assert "initialize_share_data" in error_message or "None" in error_message
+        with pytest.raises(
+            RuntimeError, match="Shared data not initialized.*initialize_share_data"
+        ):
+            get_data_init_lock()
 
     @pytest.mark.asyncio
     async def test_aexit_no_double_release_on_async_lock_failure(self):
@@ -144,48 +118,3 @@ class TestUnifiedLockSafety:
 
         # Main lock should have been released successfully
         main_lock.release.assert_called_once()
-
-    @pytest.mark.asyncio
-    async def test_aexit_recovery_on_main_lock_failure(self):
-        """
-        Test that __aexit__ recovery logic works when main lock release fails.
-
-        Scenario: main_lock.release() fails before async_lock is attempted.
-        Expected: Recovery logic should attempt to release async_lock to prevent
-        resource leaks.
-
-        This verifies the recovery logic still works correctly with async_lock_released tracking.
-        """
-        # Create mock locks
-        main_lock = MagicMock()
-        main_lock.acquire = MagicMock()
-
-        # Make main_lock.release() fail
-        def mock_main_release_fail():
-            raise RuntimeError("Main lock release failed")
-
-        main_lock.release = MagicMock(side_effect=mock_main_release_fail)
-
-        async_lock = AsyncMock()
-        async_lock.acquire = AsyncMock()
-        async_lock.release = MagicMock()
-
-        # Create UnifiedLock with both locks (sync mode with async_lock)
-        lock = UnifiedLock(
-            lock=main_lock, is_async=False, name="test_recovery", enable_logging=False
-        )
-        lock._async_lock = async_lock
-
-        # Try to use the lock - should fail during __aexit__
-        try:
-            async with lock:
-                pass
-        except RuntimeError as e:
-            # Should get the main lock release error
-            assert "Main lock release failed" in str(e)
-
-        # Main lock release should have been attempted
-        main_lock.release.assert_called_once()
-
-        # Recovery logic should have attempted to release async_lock
-        async_lock.release.assert_called_once()
@@ -48,7 +48,11 @@ class TestWorkspaceMigrationIsolation:
             sql_lower = sql.lower()
 
             # Count query for new table workspace data (verification before migration)
-            if "count(*)" in sql_lower and "model_1536d" in sql_lower and "where workspace" in sql_lower:
+            if (
+                "count(*)" in sql_lower
+                and "model_1536d" in sql_lower
+                and "where workspace" in sql_lower
+            ):
                 return new_table_record_count  # Initially 0
 
             # Count query with workspace filter (legacy table) - for workspace count
@@ -60,11 +64,19 @@ class TestWorkspaceMigrationIsolation:
                 return {"count": 0}
 
             # Count query for legacy table (total, no workspace filter)
-            elif "count(*)" in sql_lower and "lightrag" in sql_lower and "where workspace" not in sql_lower:
+            elif (
+                "count(*)" in sql_lower
+                and "lightrag" in sql_lower
+                and "where workspace" not in sql_lower
+            ):
                 return {"count": 5}  # Total records in legacy
 
             # SELECT with workspace filter for migration
-            elif "select * from" in sql_lower and "where workspace" in sql_lower and multirows:
+            elif (
+                "select * from" in sql_lower
+                and "where workspace" in sql_lower
+                and multirows
+            ):
                 workspace = params[0] if params else None
                 offset = params[1] if len(params) > 1 else 0
                 if workspace == "workspace_a" and offset == 0:
@@ -94,7 +106,9 @@ class TestWorkspaceMigrationIsolation:
         # Mock _pg_table_exists, _pg_create_table, and _pg_migrate_workspace_data
         from unittest.mock import patch
 
-        async def mock_migrate_workspace_data(db, legacy, new, workspace, expected_count, dim):
+        async def mock_migrate_workspace_data(
+            db, legacy, new, workspace, expected_count, dim
+        ):
             # Simulate migration by updating count
             new_table_record_count["count"] = expected_count
             return expected_count
@@ -123,7 +137,9 @@ class TestWorkspaceMigrationIsolation:
         # Verify the migration function was called with the correct workspace
         # The mock_migrate_workspace_data tracks that the migration was triggered
         # with workspace_a data (2 records)
-        assert new_table_record_count["count"] == 2, "Should have migrated 2 records from workspace_a"
+        assert (
+            new_table_record_count["count"] == 2
+        ), "Should have migrated 2 records from workspace_a"
 
     @pytest.mark.asyncio
     async def test_migration_without_workspace_raises_error(self):
@@ -175,7 +191,11 @@ class TestWorkspaceMigrationIsolation:
             sql_lower = sql.lower()
 
             # Count query for new table workspace data (should be 0 initially)
-            if "count(*)" in sql_lower and "model_1536d" in sql_lower and "where workspace" in sql_lower:
+            if (
+                "count(*)" in sql_lower
+                and "model_1536d" in sql_lower
+                and "where workspace" in sql_lower
+            ):
                 return new_table_count
 
             # Count query with workspace filter (legacy table)
@@ -184,7 +204,11 @@ class TestWorkspaceMigrationIsolation:
                 return {"count": 1}  # 1 record for the queried workspace
 
             # Count query for legacy table total (no workspace filter)
-            elif "count(*)" in sql_lower and "lightrag" in sql_lower and "where workspace" not in sql_lower:
+            elif (
+                "count(*)" in sql_lower
+                and "lightrag" in sql_lower
+                and "where workspace" not in sql_lower
+            ):
                 return {"count": 3}  # 3 total records in legacy
 
             return {}
@@ -194,7 +218,9 @@ class TestWorkspaceMigrationIsolation:
 
         from unittest.mock import patch
 
-        async def mock_migrate_workspace_data(db, legacy, new, workspace, expected_count, dim):
+        async def mock_migrate_workspace_data(
+            db, legacy, new, workspace, expected_count, dim
+        ):
             # Simulate migration by updating count
             new_table_count["count"] = expected_count
             return expected_count