Files
noteflow/support/db_utils.py
Travis Vasceannie 301482c410
Some checks failed
CI / test-python (push) Successful in 8m41s
CI / test-typescript (push) Failing after 6m2s
CI / test-rust (push) Failing after 4m28s
Refactor: Improve CI workflow robustness and test environment variable management, and enable parallel quality test execution.
2026-01-26 02:04:38 +00:00

250 lines
8.3 KiB
Python

"""PostgreSQL testcontainer fixtures and utilities.
Provides a singleton PostgreSQL container for integration tests. When running with
pytest-xdist, each worker gets its own isolated database within the shared container
for safe parallel execution.
"""
from __future__ import annotations
import atexit
import os
import threading
import time
from importlib import import_module
from typing import TYPE_CHECKING
from urllib.parse import quote
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from noteflow.infrastructure.persistence.models import Base
if TYPE_CHECKING:
from typing import Self
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
def get_xdist_worker_id() -> str:
    """Get pytest-xdist worker ID (e.g., 'gw0', 'gw1') or 'main' for sequential runs."""
    # xdist exports PYTEST_XDIST_WORKER per worker; absent means a
    # single-process (sequential) test run.
    worker = os.environ.get("PYTEST_XDIST_WORKER")
    return "main" if worker is None else worker
def get_worker_database_name(prefix: str = "noteflow_test") -> str:
    """Get database name for current xdist worker (e.g., 'noteflow_test_gw0').

    Args:
        prefix: Base database name; the xdist worker id is appended for
            parallel workers. Defaults to the project test database name,
            so existing callers are unaffected.

    Returns:
        ``prefix`` unchanged for sequential runs, otherwise
        ``f"{prefix}_{worker_id}"``.
    """
    # Read the worker id directly (same semantics as get_xdist_worker_id)
    # so this helper stands alone.
    worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main")
    return prefix if worker_id == "main" else f"{prefix}_{worker_id}"
class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with per-worker database support."""

    def __init__(
        self,
        image: str = "pgvector/pgvector:pg16",
        username: str = "test",
        password: str = "test",
        dbname: str = "postgres",
        port: int = 5432,
    ) -> None:
        """Configure (but do not start) a Postgres container.

        Args:
            image: Docker image to run; defaults to pgvector-enabled Postgres 16.
            username: Superuser name, injected via POSTGRES_USER.
            password: Superuser password, injected via POSTGRES_PASSWORD.
            dbname: Initial database name, injected via POSTGRES_DB.
            port: Container-internal Postgres port to expose.
        """
        self.username = username
        self.password = password
        self.dbname = dbname
        self.port = port
        # Databases already created in this container; lets
        # ensure_database_exists() short-circuit repeat calls.
        self._created_databases: set[str] = set()
        # Serializes database existence checks/creation across threads.
        self._db_lock = threading.Lock()
        # Lazy import so merely importing this module does not require the
        # testcontainers package to be installed.
        container_module = import_module("testcontainers.core.container")
        docker_container_cls = container_module.DockerContainer
        self._container = (
            docker_container_cls(image)
            .with_env("POSTGRES_USER", username)
            .with_env("POSTGRES_PASSWORD", password)
            .with_env("POSTGRES_DB", dbname)
            .with_exposed_ports(port)
        )

    def start(self) -> Self:
        """Start the container and block until Postgres answers queries."""
        self._container.start()
        self._wait_until_ready()
        return self

    def stop(self) -> None:
        """Stop the underlying Docker container."""
        self._container.stop()

    def get_host(self) -> str:
        """Return the host/IP on which the container's ports are reachable."""
        return str(self._container.get_container_host_ip())

    def get_port(self) -> int:
        """Return the host port mapped to the container's Postgres port."""
        # NOTE(review): relies on a private testcontainers method; the public
        # get_exposed_port() would be more future-proof — confirm availability.
        return int(self._container._get_exposed_port(self.port))

    def get_connection_url(self, dbname: str | None = None) -> str:
        """Return SQLAlchemy connection URL for specified database.

        Args:
            dbname: Target database; falls back to the container's default.
        """
        host = self.get_host()
        port = self.get_port()
        db = dbname or self.dbname
        # Percent-encode the password for URL embedding; space and '+' are
        # deliberately left unescaped.
        quoted_password = quote(self.password, safe=" +")
        return f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{db}"

    def ensure_database_exists(self, dbname: str) -> None:
        """Create database if it doesn't exist (thread-safe, idempotent).

        Raises:
            RuntimeError: If CREATE DATABASE fails for any reason other than
                the database already existing.
        """
        with self._db_lock:
            if dbname in self._created_databases:
                return
            # Shell-escape for single quotes: close the quoted string, emit a
            # double-quoted single quote, reopen ('\"'\"') so the password
            # survives the sh -c single-quoting below.
            escaped_password = self.password.replace("'", "'\"'\"'")
            # Check if database exists, create if not
            check_cmd = [
                "sh",
                "-c",
                (
                    f"PGPASSWORD='{escaped_password}' "
                    f"psql --username {self.username} --dbname postgres --host 127.0.0.1 "
                    f"-tAc \"SELECT 1 FROM pg_database WHERE datname='{dbname}'\""
                ),
            ]
            result = self._container.exec(check_cmd)
            # exec output is raw bytes; a bare b"1" means the row exists.
            if result.output and result.output.strip() == b"1":
                self._created_databases.add(dbname)
                return
            # Create the database
            create_cmd = [
                "sh",
                "-c",
                (
                    f"PGPASSWORD='{escaped_password}' "
                    f"psql --username {self.username} --dbname postgres --host 127.0.0.1 "
                    f"-c 'CREATE DATABASE {dbname}'"
                ),
            ]
            result = self._container.exec(create_cmd)
            if result.exit_code != 0:
                error_msg = result.output.decode(errors="ignore") if result.output else ""
                # Tolerate losing a creation race to another process.
                if "already exists" not in error_msg:
                    raise RuntimeError(f"Failed to create database {dbname}: {error_msg}")
            self._created_databases.add(dbname)

    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
        """Poll ``select 1`` inside the container until Postgres responds.

        Args:
            timeout: Maximum seconds to wait before giving up.
            interval: Seconds to sleep between polls.

        Raises:
            TimeoutError: If Postgres is not ready within ``timeout`` seconds;
                includes the last captured psql output when available.
        """
        start_time = time.time()
        # Same close-quote/reopen trick as in ensure_database_exists.
        escaped_password = self.password.replace("'", "'\"'\"'")
        cmd = [
            "sh",
            "-c",
            (
                f"PGPASSWORD='{escaped_password}' "
                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
                "-c 'select 1;'"
            ),
        ]
        last_error: str | None = None
        while True:
            result = self._container.exec(cmd)
            if result.exit_code == 0:
                return
            if result.output:
                # Remember the most recent failure for the timeout message.
                last_error = result.output.decode(errors="ignore")
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    "Postgres container did not become ready in time"
                    + (f": {last_error}" if last_error else "")
                )
            time.sleep(interval)
# Module-level container singleton
# Process-wide container instance, created lazily by get_or_create_container().
_container: PgTestContainer | None = None
# Guards creation/teardown of the singleton across threads.
_container_lock = threading.Lock()
# True once the atexit teardown hook has been registered (register only once).
_cleanup_registered = False
def get_or_create_container() -> tuple[PgTestContainer, str]:
    """Get or create container and return (container, worker-specific database URL)."""
    global _container, _cleanup_registered
    with _container_lock:
        # Lazily start the shared container on first use and arrange for it
        # to be torn down when the test process exits.
        if _container is None:
            _container = PgTestContainer().start()
            if not _cleanup_registered:
                atexit.register(_atexit_cleanup)
                _cleanup_registered = True
        # Each xdist worker gets its own database inside the shared container.
        worker_db = get_worker_database_name()
        _container.ensure_database_exists(worker_db)
        # Swap the sync driver scheme for asyncpg in the returned URL.
        sync_url = _container.get_connection_url(worker_db)
        return _container, sync_url.replace(
            "postgresql+psycopg2://", "postgresql+asyncpg://"
        )
def _atexit_cleanup() -> None:
    """Best-effort shutdown of the shared container at interpreter exit."""
    global _container
    if _container is None:
        return
    try:
        _container.stop()
    except Exception:
        # Exit-time cleanup must never raise.
        pass
    _container = None
def stop_container() -> None:
    """Explicitly stop and forget the shared container (thread-safe)."""
    global _container
    with _container_lock:
        if _container is None:
            return
        _container.stop()
        _container = None
async def initialize_test_schema(conn: AsyncConnection) -> None:
    """Initialize test database schema with pgvector extension and all tables."""
    # Rebuild from scratch: enable pgvector, recreate the noteflow schema,
    # then emit DDL for every ORM-mapped table.
    for ddl in (
        "CREATE EXTENSION IF NOT EXISTS vector",
        "DROP SCHEMA IF EXISTS noteflow CASCADE",
        "CREATE SCHEMA noteflow",
    ):
        await conn.execute(text(ddl))
    await conn.run_sync(Base.metadata.create_all)
    # Seed default workspace and user for FK constraints
    default_id = "00000000-0000-0000-0000-000000000001"
    workspace_insert = (
        "INSERT INTO noteflow.workspaces "
        "(id, name, is_default, settings, created_at, updated_at, metadata) "
        "VALUES (:id, 'Default Workspace', true, '{}', NOW(), NOW(), '{}') "
        "ON CONFLICT (id) DO NOTHING"
    )
    user_insert = (
        "INSERT INTO noteflow.users (id, display_name, created_at, updated_at, metadata) "
        "VALUES (:id, 'Test User', NOW(), NOW(), '{}') "
        "ON CONFLICT (id) DO NOTHING"
    )
    await conn.execute(text(workspace_insert), {"id": default_id})
    await conn.execute(text(user_insert), {"id": default_id})
async def cleanup_test_schema(conn: AsyncConnection) -> None:
    """Drop the noteflow schema (and everything in it) from the test database."""
    await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
def create_test_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Build an AsyncSession factory bound to *engine* for tests.

    Sessions keep ORM objects usable after commit (expire_on_commit=False)
    and never autoflush, so tests control exactly when SQL is emitted.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autocommit": False,
        "autoflush": False,
    }
    return async_sessionmaker(engine, **session_options)
def create_test_engine(database_url: str) -> AsyncEngine:
    """Create a non-echoing async engine for the given database URL."""
    return create_async_engine(database_url, echo=False)