300 lines
13 KiB
Python
300 lines
13 KiB
Python
"""Tests for infrastructure/ai/cache.py - embedding cache with LRU and TTL."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, Final
|
|
|
|
import pytest
|
|
|
|
from noteflow.infrastructure.ai.cache import (
|
|
DEFAULT_MAX_SIZE,
|
|
DEFAULT_TTL_SECONDS,
|
|
HASH_ALGORITHM,
|
|
CachedEmbedder,
|
|
CacheEntry,
|
|
EmbeddingCache,
|
|
EmbeddingCacheStats,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from tests.infrastructure.ai.conftest import MockEmbedder
|
|
|
|
# Test constants

# Sample inputs fed to the cache and to CacheEntry.
SAMPLE_TEXT: Final[str] = "hello world"
SAMPLE_TEXT_ALT: Final[str] = "goodbye world"
SAMPLE_EMBEDDING: Final[tuple[float, ...]] = (0.1, 0.2, 0.3)

# Timestamps used by the expiry tests. With TTL_SECONDS_SHORT the entry
# created at CREATED_TIME lapses at 1100.0, well before EXPIRED_TIME; with
# TTL_SECONDS_LONG it lapses at 11000.0, well after CURRENT_TIME.
CREATED_TIME: Final[float] = 1000.0
CURRENT_TIME: Final[float] = 2000.0
EXPIRED_TIME: Final[float] = 5000.0

# TTL values and cache capacity passed to the constructors under test.
TTL_SECONDS_SHORT: Final[int] = 100
TTL_SECONDS_LONG: Final[int] = 10000
CACHE_SIZE_SMALL: Final[int] = 2

# Expected hit-rate values.
HIT_RATE_ZERO: Final[float] = 0.0
HIT_RATE_HALF: Final[float] = 0.5

# Values the cache module's constants are expected to hold.
EXPECTED_DEFAULT_MAX_SIZE: Final[int] = 1000
EXPECTED_DEFAULT_TTL: Final[int] = 3600
EXPECTED_HASH_ALGORITHM: Final[str] = "sha256"

# Counts used in assertions.
ZERO_ENTRIES: Final[int] = 0
ONE_ENTRY: Final[int] = 1
SINGLE_CALL: Final[int] = 1
FIVE_HITS: Final[int] = 5
FIVE_MISSES: Final[int] = 5
|
|
|
|
|
|
class TestCacheEntry:
    """Tests for CacheEntry dataclass."""

    def test_cache_entry_stores_embedding(self) -> None:
        """CacheEntry stores embedding tuple."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert entry.embedding == SAMPLE_EMBEDDING, "embedding should match"

    def test_cache_entry_stores_created_at(self) -> None:
        """CacheEntry stores created_at timestamp."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert entry.created_at == CREATED_TIME, "created_at should match"

    def test_cache_entry_is_frozen_dataclass(self) -> None:
        """CacheEntry is an immutable frozen dataclass.

        Checks both halves of the claim: the object is a dataclass, and
        reassigning one of its fields is rejected.
        """
        # Local import keeps this block self-contained; nothing else in the
        # module needs the dataclasses helpers.
        from dataclasses import FrozenInstanceError, is_dataclass

        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        assert is_dataclass(entry), "should be a dataclass"
        # A dataclass check alone would also pass for a mutable dataclass, so
        # attempt an actual mutation to prove the instance is frozen.
        with pytest.raises(FrozenInstanceError):
            entry.created_at = EXPIRED_TIME  # type: ignore[misc]

    def test_is_cache_expired_after_ttl(self) -> None:
        """Entry is expired when current_time exceeds created_at + ttl."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_SHORT, EXPIRED_TIME)
        assert result is True, "should be expired"

    def test_is_cache_not_expired_within_ttl(self) -> None:
        """Entry is not expired when within TTL."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_LONG, CURRENT_TIME)
        assert result is False, "should not be expired"

    def test_is_cache_not_expired_at_creation(self) -> None:
        """Entry is not expired at creation time."""
        entry = CacheEntry(embedding=SAMPLE_EMBEDDING, created_at=CREATED_TIME)
        result = entry.is_cache_expired(TTL_SECONDS_SHORT, CREATED_TIME)
        assert result is False, "should not be expired at creation"
|
|
|
|
|
|
class TestEmbeddingCacheStats:
    """Exercise EmbeddingCacheStats defaults and its hit_rate value."""

    def test_default_hits_is_zero(self) -> None:
        """A stats object built with no arguments reports zero hits."""
        fresh = EmbeddingCacheStats()
        hits = fresh.hits
        assert hits == ZERO_ENTRIES, "hits should be 0"

    def test_default_misses_is_zero(self) -> None:
        """A stats object built with no arguments reports zero misses."""
        fresh = EmbeddingCacheStats()
        misses = fresh.misses
        assert misses == ZERO_ENTRIES, "misses should be 0"

    def test_default_evictions_is_zero(self) -> None:
        """A stats object built with no arguments reports zero evictions."""
        fresh = EmbeddingCacheStats()
        evictions = fresh.evictions
        assert evictions == ZERO_ENTRIES, "evictions should be 0"

    def test_default_expirations_is_zero(self) -> None:
        """A stats object built with no arguments reports zero expirations."""
        fresh = EmbeddingCacheStats()
        expirations = fresh.expirations
        assert expirations == ZERO_ENTRIES, "expirations should be 0"

    def test_hit_rate_with_hits_and_misses(self) -> None:
        """Equal numbers of hits and misses give a hit_rate of one half."""
        balanced = EmbeddingCacheStats(hits=FIVE_HITS, misses=FIVE_MISSES)
        rate = balanced.hit_rate
        assert rate == HIT_RATE_HALF, "hit rate should be 0.5"

    def test_hit_rate_zero_total(self) -> None:
        """With neither hits nor misses recorded, hit_rate is 0.0."""
        empty = EmbeddingCacheStats()
        rate = empty.hit_rate
        assert rate == HIT_RATE_ZERO, "hit rate should be 0.0"
|
|
|
|
|
|
class TestEmbeddingCache:
    """Tests for EmbeddingCache class."""

    def test_default_max_size(self) -> None:
        """EmbeddingCache uses default max_size."""
        cache = EmbeddingCache()
        assert cache.max_size == DEFAULT_MAX_SIZE, "default max_size"

    def test_default_ttl_seconds(self) -> None:
        """EmbeddingCache uses default ttl_seconds."""
        cache = EmbeddingCache()
        assert cache.ttl_seconds == DEFAULT_TTL_SECONDS, "default ttl_seconds"

    def test_custom_max_size(self) -> None:
        """EmbeddingCache accepts custom max_size."""
        cache = EmbeddingCache(max_size=CACHE_SIZE_SMALL)
        assert cache.max_size == CACHE_SIZE_SMALL, "custom max_size"

    def test_custom_ttl_seconds(self) -> None:
        """EmbeddingCache accepts custom ttl_seconds."""
        cache = EmbeddingCache(ttl_seconds=TTL_SECONDS_SHORT)
        assert cache.ttl_seconds == TTL_SECONDS_SHORT, "custom ttl_seconds"

    @pytest.mark.asyncio
    async def test_caching_same_text_returns_same_result(
        self, mock_embedder: MockEmbedder
    ) -> None:
        """Two get_or_compute calls with the same text return equal results."""
        cache = EmbeddingCache()
        result1 = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result2 = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert result1 == result2, "same text should return same result"

    @pytest.mark.asyncio
    async def test_caching_different_texts_stores_separately(
        self, mock_embedder: MockEmbedder
    ) -> None:
        """Each distinct text occupies its own cache entry."""
        cache = EmbeddingCache()
        distinct_texts = (SAMPLE_TEXT, SAMPLE_TEXT_ALT)
        for text in distinct_texts:
            await cache.get_or_compute(text, mock_embedder)
        size = await cache.size()
        # Expected size is derived from the inputs rather than hard-coded, so
        # the assertion stays correct if the set of sample texts changes.
        assert size == len(distinct_texts), "different texts should be stored separately"

    @pytest.mark.asyncio
    async def test_get_or_compute_cache_miss(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute on an empty cache returns the embedder's embedding."""
        cache = EmbeddingCache()
        result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        expected = mock_embedder.get_expected_embedding()
        assert result == expected, "should return embedder result"

    @pytest.mark.asyncio
    async def test_get_or_compute_calls_embedder(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute records embedder call."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert SAMPLE_TEXT in mock_embedder.embed_calls, "embedder should be called"

    @pytest.mark.asyncio
    async def test_get_or_compute_cache_hit(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute returns cached value on cache hit."""
        cache = EmbeddingCache()
        first_result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        second_result = await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert first_result == second_result, "results should match"

    @pytest.mark.asyncio
    async def test_get_or_compute_caches_result(self, mock_embedder: MockEmbedder) -> None:
        """get_or_compute only calls embedder once for same text."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        assert len(mock_embedder.embed_calls) == SINGLE_CALL, "embedder called once"

    @pytest.mark.asyncio
    async def test_get_returns_cached_embedding(self, mock_embedder: MockEmbedder) -> None:
        """get returns cached embedding without computing."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result = await cache.get(SAMPLE_TEXT)
        assert result is not None, "should return cached value"

    @pytest.mark.asyncio
    async def test_get_matches_cached_value(self, mock_embedder: MockEmbedder) -> None:
        """get returns the same embedding that get_or_compute stored."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        result = await cache.get(SAMPLE_TEXT)
        expected = mock_embedder.get_expected_embedding()
        assert result == expected, "should match embedding"

    @pytest.mark.asyncio
    async def test_get_returns_none_for_uncached(self) -> None:
        """get returns None for uncached text."""
        cache = EmbeddingCache()
        result = await cache.get(SAMPLE_TEXT)
        assert result is None, "should return None for uncached"

    @pytest.mark.asyncio
    async def test_clear_returns_cleared_count(self, mock_embedder: MockEmbedder) -> None:
        """clear returns number of entries cleared."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        count = await cache.clear()
        assert count == ONE_ENTRY, "should return count of cleared entries"

    @pytest.mark.asyncio
    async def test_clear_empties_cache(self, mock_embedder: MockEmbedder) -> None:
        """clear removes all cached entries."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        await cache.clear()
        size = await cache.size()
        assert size == ZERO_ENTRIES, "cache should be empty"

    @pytest.mark.asyncio
    async def test_size_initially_zero(self) -> None:
        """size returns zero for empty cache."""
        cache = EmbeddingCache()
        initial_size = await cache.size()
        assert initial_size == ZERO_ENTRIES, "initial size should be 0"

    @pytest.mark.asyncio
    async def test_size_increments_on_cache(self, mock_embedder: MockEmbedder) -> None:
        """size increases after caching."""
        cache = EmbeddingCache()
        await cache.get_or_compute(SAMPLE_TEXT, mock_embedder)
        new_size = await cache.size()
        assert new_size == ONE_ENTRY, "size should be 1"

    def test_stats_snapshot_returns_stats(self) -> None:
        """stats_snapshot returns EmbeddingCacheStats."""
        cache = EmbeddingCache()
        stats = cache.stats_snapshot()
        assert isinstance(stats, EmbeddingCacheStats), "should return stats"
|
|
|
|
|
|
class TestCachedEmbedder:
    """Exercise the CachedEmbedder wrapper around an embedder."""

    def test_initialization_max_size(self, mock_embedder: MockEmbedder) -> None:
        """The max_size given to CachedEmbedder is visible on its cache."""
        wrapper = CachedEmbedder(mock_embedder, max_size=CACHE_SIZE_SMALL)
        configured = wrapper.cache.max_size
        assert configured == CACHE_SIZE_SMALL, "max_size should match"

    def test_initialization_ttl_seconds(self, mock_embedder: MockEmbedder) -> None:
        """The ttl_seconds given to CachedEmbedder is visible on its cache."""
        wrapper = CachedEmbedder(mock_embedder, ttl_seconds=TTL_SECONDS_SHORT)
        configured = wrapper.cache.ttl_seconds
        assert configured == TTL_SECONDS_SHORT, "ttl should match"

    @pytest.mark.asyncio
    async def test_embed_returns_embedding(self, mock_embedder: MockEmbedder) -> None:
        """embed yields the embedding the mock embedder is expected to give."""
        wrapper = CachedEmbedder(mock_embedder)
        embedding = await wrapper.embed(SAMPLE_TEXT)
        wanted = mock_embedder.get_expected_embedding()
        assert embedding == wanted, "should return embedding"

    @pytest.mark.asyncio
    async def test_embed_calls_underlying_embedder(self, mock_embedder: MockEmbedder) -> None:
        """embed forwards the text to the wrapped embedder."""
        wrapper = CachedEmbedder(mock_embedder)
        await wrapper.embed(SAMPLE_TEXT)
        recorded = mock_embedder.embed_calls
        assert SAMPLE_TEXT in recorded, "should call embedder"

    @pytest.mark.asyncio
    async def test_embed_caches_results(self, mock_embedder: MockEmbedder) -> None:
        """Embedding the same text twice reaches the wrapped embedder once."""
        wrapper = CachedEmbedder(mock_embedder)
        await wrapper.embed(SAMPLE_TEXT)
        await wrapper.embed(SAMPLE_TEXT)
        call_count = len(mock_embedder.embed_calls)
        assert call_count == SINGLE_CALL, "only call once"

    def test_cache_property_exposes_cache(self, mock_embedder: MockEmbedder) -> None:
        """The cache attribute is an EmbeddingCache instance."""
        wrapper = CachedEmbedder(mock_embedder)
        exposed = wrapper.cache
        assert isinstance(exposed, EmbeddingCache), "should expose cache"
|
|
|
|
|
|
class TestCacheConstants:
    """Pin the values of the constants exported by the cache module."""

    def test_default_max_size_value(self) -> None:
        """DEFAULT_MAX_SIZE equals the expected default capacity."""
        actual = DEFAULT_MAX_SIZE
        assert actual == EXPECTED_DEFAULT_MAX_SIZE, "default max size"

    def test_default_ttl_seconds_value(self) -> None:
        """DEFAULT_TTL_SECONDS equals the expected 3600 seconds."""
        actual = DEFAULT_TTL_SECONDS
        assert actual == EXPECTED_DEFAULT_TTL, "default TTL"

    def test_hash_algorithm_value(self) -> None:
        """HASH_ALGORITHM equals the expected algorithm name."""
        actual = HASH_ALGORITHM
        assert actual == EXPECTED_HASH_ALGORITHM, "hash algorithm"
|