"""PostgreSQL testcontainer fixtures and utilities.
|
|
|
|
Provides a singleton PostgreSQL container for integration tests. When running with
|
|
pytest-xdist, each worker gets its own isolated database within the shared container
|
|
for safe parallel execution.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import atexit
|
|
import os
|
|
import threading
|
|
import time
|
|
from importlib import import_module
|
|
from typing import TYPE_CHECKING
|
|
from urllib.parse import quote
|
|
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from noteflow.infrastructure.persistence.models import Base
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Self
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
|
|
|
|
|
|
def get_xdist_worker_id() -> str:
    """Get pytest-xdist worker ID (e.g., 'gw0', 'gw1') or 'main' for sequential runs."""
    # xdist exports PYTEST_XDIST_WORKER in each worker process; when the
    # variable is absent the suite is running sequentially.
    worker = os.environ.get("PYTEST_XDIST_WORKER")
    return worker if worker is not None else "main"
|
|
|
|
|
|
def get_worker_database_name() -> str:
    """Get database name for current xdist worker (e.g., 'noteflow_test_gw0')."""
    worker = get_xdist_worker_id()
    # Sequential runs share a single well-known database name; each xdist
    # worker gets its own suffix so parallel tests never collide.
    return "noteflow_test" if worker == "main" else f"noteflow_test_{worker}"
|
|
|
|
|
|
class PgTestContainer:
    """Minimal Postgres testcontainer wrapper with per-worker database support."""

    def __init__(
        self,
        image: str = "pgvector/pgvector:pg16",
        username: str = "test",
        password: str = "test",
        dbname: str = "postgres",
        port: int = 5432,
    ) -> None:
        """Configure (but do not start) a Postgres container.

        Args:
            image: Docker image to run; defaults to pgvector-enabled Postgres 16.
            username: Superuser name, passed via POSTGRES_USER.
            password: Superuser password, passed via POSTGRES_PASSWORD.
            dbname: Initial database created by the image (POSTGRES_DB).
            port: Container-internal Postgres port to expose.
        """
        self.username = username
        self.password = password
        self.dbname = dbname
        self.port = port
        # Databases already verified/created, so ensure_database_exists()
        # is cheap and idempotent on repeat calls.
        self._created_databases: set[str] = set()
        # Serializes database existence checks/creation across threads.
        self._db_lock = threading.Lock()

        # Imported lazily via import_module so the testcontainers dependency
        # is only needed when a container is actually constructed, not at
        # module import time.
        container_module = import_module("testcontainers.core.container")
        docker_container_cls = container_module.DockerContainer
        self._container = (
            docker_container_cls(image)
            .with_env("POSTGRES_USER", username)
            .with_env("POSTGRES_PASSWORD", password)
            .with_env("POSTGRES_DB", dbname)
            .with_exposed_ports(port)
        )

    def start(self) -> Self:
        """Start the container and block until Postgres accepts connections."""
        self._container.start()
        self._wait_until_ready()
        return self

    def stop(self) -> None:
        """Stop the underlying Docker container."""
        self._container.stop()

    def get_host(self) -> str:
        """Return the host/IP where the container's mapped ports are reachable."""
        return str(self._container.get_container_host_ip())

    def get_port(self) -> int:
        """Return the host port mapped to the container's Postgres port."""
        # NOTE(review): _get_exposed_port is a private testcontainers API —
        # confirm whether the installed version offers the public
        # get_exposed_port() instead.
        return int(self._container._get_exposed_port(self.port))

    def get_connection_url(self, dbname: str | None = None) -> str:
        """Return SQLAlchemy connection URL for specified database."""
        host = self.get_host()
        port = self.get_port()
        # Fall back to the container's default database when none is given.
        db = dbname or self.dbname
        # URL-encode the password; ' ' and '+' are deliberately left literal —
        # presumably mirroring testcontainers' own URL quoting; confirm if
        # passwords could ever contain those characters.
        quoted_password = quote(self.password, safe=" +")
        return f"postgresql+psycopg2://{self.username}:{quoted_password}@{host}:{port}/{db}"

    def ensure_database_exists(self, dbname: str) -> None:
        """Create database if it doesn't exist (thread-safe, idempotent)."""
        with self._db_lock:
            # Fast path: already created/verified in this process.
            if dbname in self._created_databases:
                return

            # POSIX shell trick: close the single-quoted string, emit a
            # double-quoted literal single quote, and reopen — keeps the
            # password safe inside the sh -c '...' wrapper below.
            escaped_password = self.password.replace("'", "'\"'\"'")
            # Check if database exists, create if not
            check_cmd = [
                "sh",
                "-c",
                (
                    f"PGPASSWORD='{escaped_password}' "
                    f"psql --username {self.username} --dbname postgres --host 127.0.0.1 "
                    f"-tAc \"SELECT 1 FROM pg_database WHERE datname='{dbname}'\""
                ),
            ]
            result = self._container.exec(check_cmd)
            # psql -tA prints a bare "1" (no headers/alignment) when the row exists.
            if result.output and result.output.strip() == b"1":
                self._created_databases.add(dbname)
                return

            # Create the database
            create_cmd = [
                "sh",
                "-c",
                (
                    f"PGPASSWORD='{escaped_password}' "
                    f"psql --username {self.username} --dbname postgres --host 127.0.0.1 "
                    f"-c 'CREATE DATABASE {dbname}'"
                ),
            ]
            result = self._container.exec(create_cmd)
            if result.exit_code != 0:
                error_msg = result.output.decode(errors="ignore") if result.output else ""
                # Tolerate a concurrent creator (e.g. another process) winning
                # the check-then-create race; anything else is fatal.
                if "already exists" not in error_msg:
                    raise RuntimeError(f"Failed to create database {dbname}: {error_msg}")

            self._created_databases.add(dbname)

    def _wait_until_ready(self, timeout: float = 30.0, interval: float = 0.5) -> None:
        """Poll ``psql -c 'select 1'`` inside the container until it succeeds.

        Args:
            timeout: Maximum seconds to wait before giving up.
            interval: Sleep between polls, in seconds.

        Raises:
            TimeoutError: If Postgres does not accept connections within
                ``timeout`` seconds; includes the last psql error output.
        """
        start_time = time.time()
        # Same single-quote escaping trick as in ensure_database_exists().
        escaped_password = self.password.replace("'", "'\"'\"'")
        cmd = [
            "sh",
            "-c",
            (
                f"PGPASSWORD='{escaped_password}' "
                f"psql --username {self.username} --dbname {self.dbname} --host 127.0.0.1 "
                "-c 'select 1;'"
            ),
        ]
        last_error: str | None = None

        while True:
            result = self._container.exec(cmd)
            if result.exit_code == 0:
                return
            # Remember the most recent failure so the timeout error is useful.
            if result.output:
                last_error = result.output.decode(errors="ignore")
            if time.time() - start_time > timeout:
                raise TimeoutError(
                    "Postgres container did not become ready in time"
                    + (f": {last_error}" if last_error else "")
                )
            time.sleep(interval)
|
|
|
|
|
|
# Module-level container singleton
# Lazily created by get_or_create_container(); torn down by stop_container()
# or the registered atexit hook.
_container: PgTestContainer | None = None
# Guards creation/teardown of the singleton across threads.
_container_lock = threading.Lock()
# True once the atexit teardown hook has been registered (registered at most once).
_cleanup_registered = False
|
|
|
|
|
|
def get_or_create_container() -> tuple[PgTestContainer, str]:
    """Get or create container and return (container, worker-specific database URL).

    Lazily starts the singleton container on first use and registers a
    process-exit cleanup hook exactly once.
    """
    global _container, _cleanup_registered

    with _container_lock:
        container = _container
        if container is None:
            container = PgTestContainer().start()
            _container = container

        if not _cleanup_registered:
            _cleanup_registered = True
            atexit.register(_atexit_cleanup)

        # Each xdist worker owns its own database inside the shared container.
        worker_db = get_worker_database_name()
        container.ensure_database_exists(worker_db)

        # Swap the sync driver prefix for the async one callers expect.
        sync_url = container.get_connection_url(worker_db)
        async_url = sync_url.replace("postgresql+psycopg2://", "postgresql+asyncpg://")
        return container, async_url
|
|
|
|
|
|
def _atexit_cleanup() -> None:
    """Best-effort shutdown of the singleton container at interpreter exit."""
    global _container
    if _container is None:
        return
    try:
        _container.stop()
    except Exception:
        # Deliberately swallowed: shutdown races at interpreter exit
        # (e.g. the Docker daemon already being gone) are not actionable.
        pass
    _container = None
|
|
|
|
|
|
def stop_container() -> None:
    """Stop and discard the singleton container, if one is running."""
    global _container
    with _container_lock:
        if _container is None:
            return
        _container.stop()
        _container = None
|
|
|
|
|
|
async def initialize_test_schema(conn: AsyncConnection) -> None:
    """Initialize test database schema with pgvector extension and all tables."""
    # Rebuild the noteflow schema from scratch, then create every ORM table.
    for ddl in (
        "CREATE EXTENSION IF NOT EXISTS vector",
        "DROP SCHEMA IF EXISTS noteflow CASCADE",
        "CREATE SCHEMA noteflow",
    ):
        await conn.execute(text(ddl))
    await conn.run_sync(Base.metadata.create_all)

    # Seed default workspace and user for FK constraints
    seed_id = "00000000-0000-0000-0000-000000000001"
    workspace_insert = (
        "INSERT INTO noteflow.workspaces "
        "(id, name, is_default, settings, created_at, updated_at, metadata) "
        "VALUES (:id, 'Default Workspace', true, '{}', NOW(), NOW(), '{}') "
        "ON CONFLICT (id) DO NOTHING"
    )
    user_insert = (
        "INSERT INTO noteflow.users (id, display_name, created_at, updated_at, metadata) "
        "VALUES (:id, 'Test User', NOW(), NOW(), '{}') "
        "ON CONFLICT (id) DO NOTHING"
    )
    await conn.execute(text(workspace_insert), {"id": seed_id})
    await conn.execute(text(user_insert), {"id": seed_id})
|
|
|
|
|
|
async def cleanup_test_schema(conn: AsyncConnection) -> None:
    """Drop the noteflow schema (and everything in it) from the test database."""
    await conn.execute(text("DROP SCHEMA IF EXISTS noteflow CASCADE"))
|
|
|
|
|
|
def create_test_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Build an async session factory for tests.

    Disables expire-on-commit, autocommit, and autoflush so test code keeps
    full control over when objects are flushed and refreshed.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autocommit": False,
        "autoflush": False,
    }
    return async_sessionmaker(engine, **session_options)
|
|
|
|
|
|
def create_test_engine(database_url: str) -> AsyncEngine:
    """Create a non-echoing async engine for the given database URL."""
    engine = create_async_engine(database_url, echo=False)
    return engine