# Files: noteflow/tests/infrastructure/ai/test_cache.py
# Commit d8090a98e8 by Travis Vasceannie: "ci/cd fixes" (2026-01-26 00:28:15 +00:00)
# CI status: some checks failed
#   - CI / test-typescript (push): has been cancelled
#   - CI / test-rust (push): has been cancelled
#   - CI / test-python (push): has been cancelled
# File info: 300 lines, 13 KiB, Python

"""Tests for infrastructure/ai/cache.py - embedding cache with LRU and TTL."""
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Final

import pytest

from noteflow.infrastructure.ai.cache import (
    DEFAULT_MAX_SIZE,
    DEFAULT_TTL_SECONDS,
    HASH_ALGORITHM,
    CachedEmbedder,
    CacheEntry,
    EmbeddingCache,
    EmbeddingCacheStats,
)

if TYPE_CHECKING:
    from tests.infrastructure.ai.conftest import MockEmbedder
# ---------------------------------------------------------------------------
# Test constants
# ---------------------------------------------------------------------------

# Inputs handed to the cache and the embedder.
SAMPLE_TEXT: Final[str] = "hello world"
SAMPLE_TEXT_ALT: Final[str] = "goodbye world"
SAMPLE_EMBEDDING: Final[tuple[float, ...]] = (0.1, 0.2, 0.3)

# Timestamps and TTLs for the CacheEntry expiry checks.
# CREATED_TIME + TTL_SECONDS_SHORT (1100) lies before EXPIRED_TIME (5000),
# while CURRENT_TIME (2000) lies before CREATED_TIME + TTL_SECONDS_LONG (11000).
CREATED_TIME: Final[float] = 1000.0
CURRENT_TIME: Final[float] = 2000.0
EXPIRED_TIME: Final[float] = 5000.0
TTL_SECONDS_SHORT: Final[int] = 100
TTL_SECONDS_LONG: Final[int] = 10000

# Cache sizing and hit-rate fixtures (5 hits + 5 misses -> 0.5).
CACHE_SIZE_SMALL: Final[int] = 2
FIVE_HITS: Final[int] = 5
FIVE_MISSES: Final[int] = 5
HIT_RATE_HALF: Final[float] = 0.5
HIT_RATE_ZERO: Final[float] = 0.0

# Expected values of the cache module's exported constants.
EXPECTED_DEFAULT_MAX_SIZE: Final[int] = 1000
EXPECTED_DEFAULT_TTL: Final[int] = 3600
EXPECTED_HASH_ALGORITHM: Final[str] = "sha256"

# Counts used in size and call-count assertions.
ZERO_ENTRIES: Final[int] = 0
ONE_ENTRY: Final[int] = 1
SINGLE_CALL: Final[int] = 1
class TestCacheEntry:
    """Tests for CacheEntry dataclass."""

    def test_cache_entry_stores_embedding(self) -> None:
        """CacheEntry stores embedding tuple."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert entry.embedding == SAMPLE_EMBEDDING, "embedding should match"

    def test_cache_entry_stores_created_at(self) -> None:
        """CacheEntry stores created_at timestamp."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert entry.created_at == CREATED_TIME, "created_at should match"

    def test_cache_entry_is_frozen_dataclass(self) -> None:
        """CacheEntry is an immutable frozen dataclass.

        Checking only for ``__dataclass_fields__`` would also pass for a
        mutable dataclass, so this test attempts an assignment and requires
        it to be rejected with ``FrozenInstanceError``.
        """
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert dataclasses.is_dataclass(entry), "should be a dataclass"
        with pytest.raises(dataclasses.FrozenInstanceError):
            entry.created_at = CURRENT_TIME  # type: ignore[misc]
        # The rejected assignment must not have changed the stored value.
        assert entry.created_at == CREATED_TIME, "created_at should be unchanged"

    def test_is_cache_expired_after_ttl(self) -> None:
        """Entry is expired when current_time exceeds created_at + ttl."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_SHORT, EXPIRED_TIME)
        assert result is True, "should be expired"

    def test_is_cache_not_expired_within_ttl(self) -> None:
        """Entry is not expired when within TTL."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_LONG, CURRENT_TIME)
        assert result is False, "should not be expired"

    def test_is_cache_not_expired_at_creation(self) -> None:
        """Entry is not expired at creation time."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_SHORT, CREATED_TIME)
        assert result is False, "should not be expired at creation"
class TestEmbeddingCacheStats:
    """Unit tests for the EmbeddingCacheStats counters and its hit_rate property."""

    def test_default_hits_is_zero(self) -> None:
        """A freshly constructed stats object reports no hits."""
        assert EmbeddingCacheStats().hits == ZERO_ENTRIES, "hits should be 0"

    def test_default_misses_is_zero(self) -> None:
        """A freshly constructed stats object reports no misses."""
        assert EmbeddingCacheStats().misses == ZERO_ENTRIES, "misses should be 0"

    def test_default_evictions_is_zero(self) -> None:
        """A freshly constructed stats object reports no evictions."""
        assert EmbeddingCacheStats().evictions == ZERO_ENTRIES, "evictions should be 0"

    def test_default_expirations_is_zero(self) -> None:
        """A freshly constructed stats object reports no expirations."""
        assert EmbeddingCacheStats().expirations == ZERO_ENTRIES, "expirations should be 0"

    def test_hit_rate_with_hits_and_misses(self) -> None:
        """Equal hit and miss counts produce a hit rate of one half."""
        balanced = EmbeddingCacheStats(hits=FIVE_HITS, misses=FIVE_MISSES)
        assert balanced.hit_rate == HIT_RATE_HALF, "hit rate should be 0.5"

    def test_hit_rate_zero_total(self) -> None:
        """With no lookups recorded, hit_rate is reported as 0.0."""
        assert EmbeddingCacheStats().hit_rate == HIT_RATE_ZERO, "hit rate should be 0.0"
class TestEmbeddingCache:
    """Tests for EmbeddingCache class."""

    def test_default_max_size(self) -> None:
        """EmbeddingCache uses default max_size."""
        cache = EmbeddingCache()
        assert cache.max_size == DEFAULT_MAX_SIZE, "default max_size"

    def test_default_ttl_seconds(self) -> None:
        """EmbeddingCache uses default ttl_seconds."""
        cache = EmbeddingCache()
        assert cache.ttl_seconds == DEFAULT_TTL_SECONDS, "default ttl_seconds"

    def test_custom_max_size(self) -> None:
        """EmbeddingCache accepts custom max_size."""
        cache = EmbeddingCache(max_size=CACHE_SIZE_SMALL)
        assert cache.max_size == CACHE_SIZE_SMALL, "custom max_size"

    def test_custom_ttl_seconds(self) -> None:
        """EmbeddingCache accepts custom ttl_seconds."""
        cache = EmbeddingCache(ttl_seconds=TTL_SECONDS_SHORT)
        assert cache.ttl_seconds == TTL_SECONDS_SHORT, "custom ttl_seconds"

    @pytest.mark.asyncio
    async def test_caching_same_text_returns_same_result(
        self, mock_embedder: MockEmbedder
    ) -> None:
        """Looking up the same text twice yields equal embeddings."""
        cache = EmbeddingCache()
        result1 = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result2 = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert result1 == result2, "same text should return same result"

    @pytest.mark.asyncio
    async def test_caching_different_texts_stores_separately(
        self, mock_embedder: MockEmbedder
    ) -> None:
        """Each distinct text occupies its own cache entry."""
        cache = EmbeddingCache()
        distinct_texts = (SAMPLE_TEXT, SAMPLE_TEXT_ALT)
        for text in distinct_texts:
            await cache.get_or_compute(text, mock_embedder)
        size = await cache.size()
        # Derive the expected count from the inputs rather than a bare literal.
        assert size == len(distinct_texts), "different texts should be stored separately"

    @pytest.mark.asyncio
    async def test_get_or_compute_cache_miss(self, mock_embedder: MockEmbedder) -> None:
        """On a miss, get_or_compute returns the embedder's result."""
        cache = EmbeddingCache()
        result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        expected = mock_embedder.get_expected_embedding()
        assert result == expected, "should return embedder result"

    @pytest.mark.asyncio
    async def test_get_or_compute_calls_embedder(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute records embedder call."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert SAMPLE_TEXT in mock_embedder.embed_calls, "embedder should be called"

    @pytest.mark.asyncio
    async def test_get_or_compute_cache_hit(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute returns cached value on cache hit.

        Besides value equality, the second lookup must not reach the embedder;
        otherwise this would not distinguish a hit from a recomputation.
        """
        cache = EmbeddingCache()
        first_result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        calls_after_miss = len(mock_embedder.embed_calls)
        second_result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert first_result == second_result, "results should match"
        assert len(mock_embedder.embed_calls) == calls_after_miss, (
            "cache hit should not call the embedder"
        )

    @pytest.mark.asyncio
    async def test_get_or_compute_caches_result(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute only calls embedder once for same text."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert len(mock_embedder.embed_calls) == SINGLE_CALL, "embedder called once"

    @pytest.mark.asyncio
    async def test_get_returns_cached_embedding(self, mock_embedder: MockEmbedder) -> None:
        """get returns cached embedding without computing."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result = await cache.get(SAMPLE_TEXT)
        assert result is not None, "should return cached value"

    @pytest.mark.asyncio
    async def test_get_matches_cached_value(self, mock_embedder: MockEmbedder) -> None:
        """get returns the same embedding that get_or_compute stored."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result = await cache.get(SAMPLE_TEXT)
        expected = mock_embedder.get_expected_embedding()
        assert result == expected, "should match embedding"

    @pytest.mark.asyncio
    async def test_get_returns_none_for_uncached(self) -> None:
        """get returns None for uncached text."""
        cache = EmbeddingCache()
        result = await cache.get(SAMPLE_TEXT)
        assert result is None, "should return None for uncached"

    @pytest.mark.asyncio
    async def test_clear_returns_cleared_count(self, mock_embedder: MockEmbedder) -> None:
        """clear returns number of entries cleared."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        count = await cache.clear()
        assert count == ONE_ENTRY, "should return count of cleared entries"

    @pytest.mark.asyncio
    async def test_clear_empties_cache(self, mock_embedder: MockEmbedder) -> None:
        """clear removes all cached entries."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        await cache.clear()
        size = await cache.size()
        assert size == ZERO_ENTRIES, "cache should be empty"

    @pytest.mark.asyncio
    async def test_size_initially_zero(self) -> None:
        """size returns zero for empty cache."""
        cache = EmbeddingCache()
        initial_size = await cache.size()
        assert initial_size == ZERO_ENTRIES, "initial size should be 0"

    @pytest.mark.asyncio
    async def test_size_increments_on_cache(self, mock_embedder: MockEmbedder) -> None:
        """size increases after caching."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        new_size = await cache.size()
        assert new_size == ONE_ENTRY, "size should be 1"

    def test_stats_snapshot_returns_stats(self) -> None:
        """stats_snapshot returns EmbeddingCacheStats."""
        cache = EmbeddingCache()
        stats = cache.stats_snapshot()
        assert isinstance(stats, EmbeddingCacheStats), "should return stats"
class TestCachedEmbedder:
    """Tests for CachedEmbedder wrapper class."""

    def test_initialization_max_size(self, mock_embedder: MockEmbedder) -> None:
        """max_size passed to CachedEmbedder is applied to its cache."""
        cached = CachedEmbedder(mock_embedder, max_size=CACHE_SIZE_SMALL)
        assert cached.cache.max_size == CACHE_SIZE_SMALL, "max_size should match"

    def test_initialization_ttl_seconds(self, mock_embedder: MockEmbedder) -> None:
        """ttl_seconds passed to CachedEmbedder is applied to its cache."""
        cached = CachedEmbedder(mock_embedder, ttl_seconds=TTL_SECONDS_SHORT)
        assert cached.cache.ttl_seconds == TTL_SECONDS_SHORT, "ttl should match"

    @pytest.mark.asyncio
    async def test_embed_returns_embedding(self, mock_embedder: MockEmbedder) -> None:
        """embed returns the embedding produced by the wrapped embedder."""
        cached = CachedEmbedder(mock_embedder)
        result = await cached.embed(SAMPLE_TEXT)
        expected = mock_embedder.get_expected_embedding()
        assert result == expected, "should return embedding"

    @pytest.mark.asyncio
    async def test_embed_calls_underlying_embedder(self, mock_embedder: MockEmbedder) -> None:
        """embed calls underlying embedder."""
        cached = CachedEmbedder(mock_embedder)
        await cached.embed(SAMPLE_TEXT)
        assert SAMPLE_TEXT in mock_embedder.embed_calls, "should call embedder"

    @pytest.mark.asyncio
    async def test_embed_caches_results(self, mock_embedder: MockEmbedder) -> None:
        """embed caches results for subsequent calls.

        The repeat call must be served without hitting the wrapped embedder
        and must return the same embedding as the first call.
        """
        cached = CachedEmbedder(mock_embedder)
        first = await cached.embed(SAMPLE_TEXT)
        second = await cached.embed(SAMPLE_TEXT)
        assert len(mock_embedder.embed_calls) == SINGLE_CALL, "only call once"
        assert second == first, "cached result should match first result"

    def test_cache_property_exposes_cache(self, mock_embedder: MockEmbedder) -> None:
        """cache property provides access to underlying EmbeddingCache."""
        cached = CachedEmbedder(mock_embedder)
        assert isinstance(cached.cache, EmbeddingCache), "should expose cache"
class TestCacheConstants:
    """Pin the default values exported by the cache module."""

    def test_default_max_size_value(self) -> None:
        """DEFAULT_MAX_SIZE equals the expected capacity (1000)."""
        expected_size = EXPECTED_DEFAULT_MAX_SIZE
        assert DEFAULT_MAX_SIZE == expected_size, "default max size"

    def test_default_ttl_seconds_value(self) -> None:
        """DEFAULT_TTL_SECONDS equals 3600, i.e. one hour."""
        expected_ttl = EXPECTED_DEFAULT_TTL
        assert DEFAULT_TTL_SECONDS == expected_ttl, "default TTL"

    def test_hash_algorithm_value(self) -> None:
        """HASH_ALGORITHM names the sha256 digest."""
        expected_algorithm = EXPECTED_HASH_ALGORITHM
        assert HASH_ALGORITHM == expected_algorithm, "hash algorithm"