"""
|
|
Comprehensive tests for AI Manager fixes.
|
|
|
|
Tests the specific bug fixes implemented in the AI Manager:
|
|
1. GroqProvider incorrectly using OpenAI client for embeddings (now raises NotImplementedError)
|
|
2. Unused latency calculations (now properly logged)
|
|
3. Deprecated .dict() method (now uses .model_dump())
|
|
4. Added proper TypedDict types for AIMetadata
|
|
5. Added embedding dimension configuration
|
|
"""

import asyncio
import logging
import time
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from config.ai_providers import (AIProviderType, TaskType,
                                 get_embedding_dimension)
from config.settings import Settings
from core.ai_manager import (AIMetadata, AIProviderManager, AIResponse,
                             GroqProvider, OpenAIProvider, TranscriptionResult,
                             TranscriptionSegment)

class TestGroqProviderEmbeddingFix:
    """Test suite for GroqProvider embedding fix."""

    @pytest.fixture
    def mock_provider_config(self):
        """Create mock Groq provider config."""
        config = MagicMock()
        config.name = "Groq"
        config.rate_limit_rpm = 30
        return config

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings with Groq API key."""
        settings = MagicMock(spec=Settings)
        settings.groq_api_key = "test_groq_key"
        return settings

    @pytest.fixture
    def groq_provider(self, mock_provider_config, mock_settings):
        """Create GroqProvider instance."""
        return GroqProvider(mock_provider_config, mock_settings)

    @pytest.mark.asyncio
    async def test_groq_embedding_raises_not_implemented_error(self, groq_provider):
        """Test that GroqProvider.generate_embedding raises NotImplementedError."""
        with pytest.raises(NotImplementedError) as exc_info:
            await groq_provider.generate_embedding("test text")

        error_message = str(exc_info.value)
        assert "Groq does not provide embedding APIs" in error_message
        assert "Use OpenAI or another provider for embeddings" in error_message

    @pytest.mark.asyncio
    async def test_groq_embedding_error_message_is_descriptive(self, groq_provider):
        """Test that the error message properly indicates Groq doesn't support embeddings."""
        with pytest.raises(
            NotImplementedError,
            match="Groq does not provide embedding APIs.*Use OpenAI.*embeddings",
        ):
            await groq_provider.generate_embedding("test text")

    @pytest.mark.asyncio
    async def test_groq_other_methods_still_work(self, groq_provider):
        """Test that other methods in GroqProvider still work properly."""
        # Test transcription still works
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.text = "Transcribed with Groq"
        mock_response.language = "en"
        mock_response.duration = 10.0
        mock_response.segments = []
        mock_client.audio.transcriptions.create.return_value = mock_response
        groq_provider.client = mock_client
        groq_provider._initialized = True

        with patch("core.ai_manager.get_model_config") as mock_get_model:
            mock_model = MagicMock()
            mock_model.name = "whisper-large-v3"
            mock_get_model.return_value = mock_model

            result = await groq_provider.transcribe_audio(b"fake_audio")
            assert result.text == "Transcribed with Groq"
            assert result.provider == "groq"

        # Test text generation still works
        mock_usage = MagicMock()
        mock_usage.total_tokens = 100
        mock_usage.prompt_tokens = 50

        mock_message = MagicMock()
        mock_message.content = "Groq text response"

        mock_choice = MagicMock()
        mock_choice.message = mock_message

        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = mock_usage

        mock_client.chat.completions.create.return_value = mock_response

        with patch("core.ai_manager.get_model_config") as mock_get_model:
            mock_model = MagicMock()
            mock_model.name = "llama3-70b-8192"
            mock_model.max_tokens = 1000
            mock_model.temperature = 0.3
            mock_model.top_p = 1.0
            mock_model.cost_per_1k_tokens = 0.0008
            mock_get_model.return_value = mock_model

            result = await groq_provider.generate_text("test prompt", TaskType.ANALYSIS)
            assert result.content == "Groq text response"
            assert result.provider == "groq"
            assert result.success is True

class TestLatencyLogging:
    """Test suite for latency logging fixes."""

    @pytest.fixture
    def mock_openai_provider_config(self):
        """Create mock OpenAI provider config."""
        config = MagicMock()
        config.name = "OpenAI"
        config.rate_limit_rpm = 500
        return config

    @pytest.fixture
    def mock_groq_provider_config(self):
        """Create mock Groq provider config."""
        config = MagicMock()
        config.name = "Groq"
        config.rate_limit_rpm = 30
        return config

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings with API keys."""
        settings = MagicMock(spec=Settings)
        settings.openai_api_key = "test_openai_key"
        settings.groq_api_key = "test_groq_key"
        return settings

    @pytest.mark.asyncio
    async def test_openai_transcription_logs_latency(
        self, mock_openai_provider_config, mock_settings, caplog
    ):
        """Test that OpenAI transcription properly logs latency."""
        openai_provider = OpenAIProvider(mock_openai_provider_config, mock_settings)

        mock_response = MagicMock()
        mock_response.text = "Transcribed text"
        mock_response.language = "en"
        mock_response.duration = 10.5
        mock_response.segments = []

        mock_client = AsyncMock()
        mock_client.audio.transcriptions.create.return_value = mock_response
        openai_provider.client = mock_client
        openai_provider._initialized = True

        with (
            patch("core.ai_manager.get_model_config") as mock_get_model,
            caplog.at_level(logging.DEBUG),
        ):
            mock_model = MagicMock()
            mock_model.name = "whisper-1"
            mock_get_model.return_value = mock_model

            # Add a small delay to ensure latency > 0
            async def mock_create(*args, **kwargs):
                await asyncio.sleep(0.01)
                return mock_response

            mock_client.audio.transcriptions.create = mock_create

            await openai_provider.transcribe_audio(b"fake_audio")

            # Check that latency was logged
            log_messages = [record.message for record in caplog.records]
            latency_logs = [
                msg
                for msg in log_messages
                if "OpenAI transcription completed" in msg and "s" in msg
            ]

            assert (
                len(latency_logs) > 0
            ), f"Expected latency log message. Got: {log_messages}"

            # Verify latency value is reasonable (> 0)
            for log_msg in latency_logs:
                # Extract latency value from log message like "OpenAI transcription completed in 0.01s"
                if "completed in" in log_msg and "s" in log_msg:
                    latency_str = log_msg.split("completed in ")[1].split("s")[0]
                    latency_value = float(latency_str)
                    assert latency_value > 0, f"Latency should be > 0, got: {latency_value}"

    @pytest.mark.asyncio
    async def test_groq_transcription_logs_latency(
        self, mock_groq_provider_config, mock_settings, caplog
    ):
        """Test that Groq transcription properly logs latency."""
        groq_provider = GroqProvider(mock_groq_provider_config, mock_settings)

        mock_response = MagicMock()
        mock_response.text = "Groq transcribed text"
        mock_response.language = "en"
        mock_response.duration = 8.0
        mock_response.segments = []

        mock_client = AsyncMock()
        mock_client.audio.transcriptions.create.return_value = mock_response
        groq_provider.client = mock_client
        groq_provider._initialized = True

        with (
            patch("core.ai_manager.get_model_config") as mock_get_model,
            caplog.at_level(logging.DEBUG),
        ):
            mock_model = MagicMock()
            mock_model.name = "whisper-large-v3"
            mock_get_model.return_value = mock_model

            # Add a small delay to ensure latency > 0
            async def mock_create(*args, **kwargs):
                await asyncio.sleep(0.02)
                return mock_response

            mock_client.audio.transcriptions.create = mock_create

            await groq_provider.transcribe_audio(b"fake_groq_audio")

            # Check that latency was logged
            log_messages = [record.message for record in caplog.records]
            latency_logs = [
                msg
                for msg in log_messages
                if "Groq transcription completed" in msg and "s" in msg
            ]

            assert (
                len(latency_logs) > 0
            ), f"Expected latency log message. Got: {log_messages}"

            # Verify latency value is reasonable (> 0)
            for log_msg in latency_logs:
                if "completed in" in log_msg and "s" in log_msg:
                    latency_str = log_msg.split("completed in ")[1].split("s")[0]
                    latency_value = float(latency_str)
                    assert latency_value > 0, f"Latency should be > 0, got: {latency_value}"

    @pytest.mark.asyncio
    async def test_latency_values_are_reasonable(
        self, mock_openai_provider_config, mock_settings
    ):
        """Test that latency values are reasonable (> 0 and not absurdly high)."""
        openai_provider = OpenAIProvider(mock_openai_provider_config, mock_settings)

        mock_response = MagicMock()
        mock_response.text = "Test transcription"
        mock_response.language = "en"
        mock_response.duration = 5.0
        mock_response.segments = []

        # Create a mock that introduces a controlled delay
        async def mock_transcribe(*args, **kwargs):
            await asyncio.sleep(0.05)  # 50ms delay
            return mock_response

        mock_client = AsyncMock()
        mock_client.audio.transcriptions.create = mock_transcribe
        openai_provider.client = mock_client
        openai_provider._initialized = True

        with patch("core.ai_manager.get_model_config") as mock_get_model:
            mock_model = MagicMock()
            mock_model.name = "whisper-1"
            mock_get_model.return_value = mock_model

            start_time = time.time()
            result = await openai_provider.transcribe_audio(b"fake_audio")
            end_time = time.time()

            actual_duration = end_time - start_time

            # Verify the transcription result contains expected text
            assert result.text == "Test transcription"

            # Verify actual duration is reasonable (at least our 50ms delay)
            assert (
                actual_duration >= 0.05
            ), f"Actual duration {actual_duration} should be >= 0.05s"
            assert (
                actual_duration < 1.0
            ), f"Actual duration {actual_duration} should be < 1.0s for test"

class TestPydanticModelUpdates:
    """Test suite for Pydantic model updates (.model_dump() vs .dict())."""

    @pytest.fixture
    def mock_provider_config(self):
        """Create mock provider config."""
        config = MagicMock()
        config.name = "OpenAI"
        config.rate_limit_rpm = 500
        return config

    @pytest.fixture
    def mock_settings(self):
        """Create mock settings."""
        settings = MagicMock(spec=Settings)
        settings.openai_api_key = "test_openai_key"
        return settings

    @pytest.mark.asyncio
    async def test_ai_response_uses_model_dump(
        self, mock_provider_config, mock_settings
    ):
        """Test that responses use .model_dump() instead of .dict()."""
        openai_provider = OpenAIProvider(mock_provider_config, mock_settings)

        # Mock usage object with model_dump method
        mock_usage = MagicMock()
        mock_usage.total_tokens = 150
        mock_usage.model_dump.return_value = {
            "prompt_tokens": 50,
            "completion_tokens": 100,
            "total_tokens": 150,
        }
        # Ensure .dict() method doesn't exist or raises error
        mock_usage.dict = MagicMock(
            side_effect=AttributeError("'Usage' object has no attribute 'dict'")
        )

        mock_message = MagicMock()
        mock_message.content = "Generated response"

        mock_choice = MagicMock()
        mock_choice.message = mock_message

        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = mock_usage

        mock_client = AsyncMock()
        mock_client.chat.completions.create.return_value = mock_response
        openai_provider.client = mock_client
        openai_provider._initialized = True

        with patch("core.ai_manager.get_model_config") as mock_get_model:
            mock_model = MagicMock()
            mock_model.name = "gpt-4"
            mock_model.max_tokens = 1000
            mock_model.temperature = 0.7
            mock_model.top_p = 1.0
            mock_model.frequency_penalty = 0.0
            mock_model.presence_penalty = 0.0
            mock_model.cost_per_1k_tokens = 0.03
            mock_get_model.return_value = mock_model

            result = await openai_provider.generate_text(
                "Test prompt", TaskType.ANALYSIS
            )

            # Verify the result was created successfully
            assert result.content == "Generated response"
            assert result.success is True

            # Verify model_dump was called, not dict
            mock_usage.model_dump.assert_called_once()

            # Verify metadata structure contains proper usage data
            assert result.metadata is not None
            assert "usage" in result.metadata
            assert result.metadata["usage"]["total_tokens"] == 150

    @pytest.mark.asyncio
    async def test_metadata_structure_validation(
        self, mock_provider_config, mock_settings
    ):
        """Test metadata structure with proper TypedDict validation."""

        # Create proper AIMetadata structure
        test_metadata: AIMetadata = {
            "usage": {"total_tokens": 100},
            "input_tokens": 40,
            "output_tokens": 60,
            "prompt_tokens": 40,
            "completion_tokens": 60,
        }

        # Create AIResponse with proper metadata
        response = AIResponse(
            content="Test response",
            provider="openai",
            model="gpt-4",
            tokens_used=100,
            cost=0.003,
            latency=0.5,
            success=True,
            metadata=test_metadata,
        )

        # Verify metadata structure
        assert response.metadata is not None
        assert "usage" in response.metadata
        assert "input_tokens" in response.metadata
        assert "output_tokens" in response.metadata
        assert response.metadata["usage"]["total_tokens"] == 100
        assert response.metadata["input_tokens"] == 40
        assert response.metadata["output_tokens"] == 60

    @pytest.mark.asyncio
    async def test_transcription_segment_structure(self):
        """Test TranscriptionSegment TypedDict structure validation."""
        # Create proper TranscriptionSegment structure
        test_segment: TranscriptionSegment = {
            "id": 1,
            "seek": 0,
            "start": 0.0,
            "end": 2.5,
            "text": "Hello world",
            "tokens": [1, 2, 3],
            "temperature": 0.0,
            "avg_logprob": -0.5,
            "compression_ratio": 1.8,
            "no_speech_prob": 0.1,
        }

        # Verify structure
        assert test_segment["id"] == 1
        assert test_segment["text"] == "Hello world"
        assert isinstance(test_segment["tokens"], list)
        assert test_segment["start"] < test_segment["end"]

class TestTypeSafety:
    """Test suite for type safety improvements."""

    def test_ai_metadata_typed_dict_structure(self):
        """Test AIMetadata TypedDict structure."""
        # Test complete AIMetadata
        complete_metadata: AIMetadata = {
            "usage": {"total_tokens": 150, "prompt_tokens": 50},
            "input_tokens": 50,
            "output_tokens": 100,
            "prompt_tokens": 50,
            "completion_tokens": 100,
        }

        assert complete_metadata["usage"]["total_tokens"] == 150
        assert complete_metadata["input_tokens"] == 50

        # Test partial AIMetadata (total=False allows optional fields)
        partial_metadata: AIMetadata = {"usage": {"total_tokens": 100}}

        assert partial_metadata["usage"]["total_tokens"] == 100

    def test_transcription_segment_typed_dict_structure(self):
        """Test TranscriptionSegment TypedDict structure."""
        # Test complete TranscriptionSegment
        complete_segment: TranscriptionSegment = {
            "id": 1,
            "seek": 0,
            "start": 0.0,
            "end": 2.5,
            "text": "Test transcription",
            "tokens": [1, 2, 3, 4],
            "temperature": 0.0,
            "avg_logprob": -0.3,
            "compression_ratio": 2.1,
            "no_speech_prob": 0.05,
        }

        assert complete_segment["id"] == 1
        assert complete_segment["text"] == "Test transcription"
        assert len(complete_segment["tokens"]) == 4

        # Test partial TranscriptionSegment (total=False allows optional fields)
        partial_segment: TranscriptionSegment = {"text": "Partial segment"}

        assert partial_segment["text"] == "Partial segment"

    def test_ai_response_proper_typing(self):
        """Test AIResponse class uses proper typing."""
        metadata: AIMetadata = {
            "usage": {"total_tokens": 100},
            "input_tokens": 40,
            "output_tokens": 60,
        }

        response = AIResponse(
            content="Test response",
            provider="openai",
            model="gpt-4",
            tokens_used=100,
            cost=0.003,
            latency=0.5,
            success=True,
            metadata=metadata,
        )

        # Verify types
        assert isinstance(response.content, str)
        assert isinstance(response.provider, str)
        assert isinstance(response.model, str)
        assert isinstance(response.tokens_used, int)
        assert isinstance(response.cost, float)
        assert isinstance(response.latency, float)
        assert isinstance(response.success, bool)
        assert response.metadata is not None
        assert isinstance(response.metadata, dict)

    def test_transcription_result_proper_typing(self):
        """Test TranscriptionResult class uses proper typing."""
        segments: List[TranscriptionSegment] = [
            {"id": 1, "text": "First segment", "start": 0.0, "end": 2.0}
        ]

        result = TranscriptionResult(
            text="Complete transcription",
            language="en",
            confidence=0.95,
            duration=10.5,
            segments=segments,
            provider="openai",
            model="whisper-1",
        )

        # Verify types
        assert isinstance(result.text, str)
        assert isinstance(result.language, str)
        assert isinstance(result.confidence, float)
        assert isinstance(result.duration, float)
        assert isinstance(result.segments, list)
        assert isinstance(result.provider, str)
        assert isinstance(result.model, str)

class TestEmbeddingDimensions:
    """Test suite for embedding dimension configuration."""

    def test_get_embedding_dimension_function(self):
        """Test get_embedding_dimension function with various models."""
        # Test OpenAI models
        assert get_embedding_dimension("text-embedding-3-small") == 1536
        assert get_embedding_dimension("text-embedding-3-large") == 3072
        assert get_embedding_dimension("text-embedding-ada-002") == 1536

        # Test other models
        assert get_embedding_dimension("nomic-embed-text") == 768
        assert get_embedding_dimension("sentence-transformers/all-MiniLM-L6-v2") == 384
        assert get_embedding_dimension("sentence-transformers/all-mpnet-base-v2") == 768

        # Test unknown model returns default
        assert get_embedding_dimension("unknown-model") == 1536

    def test_embedding_dimension_retrieval_for_openai_models(self):
        """Test embedding dimension retrieval for OpenAI models."""
        # Test small embedding model
        dim_small = get_embedding_dimension("text-embedding-3-small")
        assert dim_small == 1536

        # Test large embedding model
        dim_large = get_embedding_dimension("text-embedding-3-large")
        assert dim_large == 3072

        # Test legacy model
        dim_ada = get_embedding_dimension("text-embedding-ada-002")
        assert dim_ada == 1536

    def test_default_dimension_fallback(self):
        """Test default dimension fallback for unknown models."""
        # Test with completely unknown model names
        assert get_embedding_dimension("random-model-name") == 1536
        assert get_embedding_dimension("") == 1536
        assert get_embedding_dimension("non-existent-embedding-model") == 1536

        # Test that default is reasonable (OpenAI standard)
        default_dim = get_embedding_dimension("unknown")
        assert default_dim == 1536
        assert isinstance(default_dim, int)
        assert default_dim > 0

    @pytest.mark.parametrize(
        "model_name,expected_dimension",
        [
            ("text-embedding-3-small", 1536),
            ("text-embedding-3-large", 3072),
            ("text-embedding-ada-002", 1536),
            ("nomic-embed-text", 768),
            ("sentence-transformers/all-MiniLM-L6-v2", 384),
            ("sentence-transformers/all-mpnet-base-v2", 768),
            ("unknown-model", 1536),  # default fallback
        ],
    )
    def test_embedding_dimensions_parametrized(
        self, model_name: str, expected_dimension: int
    ):
        """Test embedding dimensions with parametrized inputs."""
        actual_dimension = get_embedding_dimension(model_name)
        assert actual_dimension == expected_dimension
        assert isinstance(actual_dimension, int)
        assert actual_dimension > 0

    @pytest.mark.asyncio
    async def test_embedding_generation_uses_correct_dimensions(self):
        """Test that embedding generation uses correct dimensions for different models."""
        # Test with OpenAI provider
        mock_provider_config = MagicMock()
        mock_provider_config.name = "OpenAI"
        mock_provider_config.rate_limit_rpm = 500

        mock_settings = MagicMock(spec=Settings)
        mock_settings.openai_api_key = "test_key"

        openai_provider = OpenAIProvider(mock_provider_config, mock_settings)

        # Mock different embedding models
        test_cases = [
            ("text-embedding-3-small", 1536),
            ("text-embedding-3-large", 3072),
            ("text-embedding-ada-002", 1536),
        ]

        for model_name, expected_dim in test_cases:
            mock_embedding_data = MagicMock()
            mock_embedding_data.embedding = [0.1] * expected_dim

            mock_response = MagicMock()
            mock_response.data = [mock_embedding_data]

            mock_client = AsyncMock()
            mock_client.embeddings.create.return_value = mock_response
            openai_provider.client = mock_client
            openai_provider._initialized = True

            with patch("core.ai_manager.get_model_config") as mock_get_model:
                mock_model = MagicMock()
                mock_model.name = model_name
                mock_get_model.return_value = mock_model

                result = await openai_provider.generate_embedding("Test text")

                # Verify the embedding has the expected dimensions
                assert len(result) == expected_dim
                assert all(isinstance(val, float) for val in result)

                # Verify the model config was called correctly
                mock_get_model.assert_called_with(
                    AIProviderType.OPENAI, TaskType.EMBEDDING
                )

class TestAIManagerFixesIntegration:
    """Integration tests for all AI Manager fixes combined."""

    @pytest.fixture
    def mock_settings(self):
        """Create comprehensive mock settings."""
        settings = MagicMock(spec=Settings)
        settings.openai_api_key = "test_openai_key"
        settings.anthropic_api_key = "test_anthropic_key"
        settings.groq_api_key = "test_groq_key"
        settings.ollama_base_url = "http://localhost:11434"
        return settings

    @pytest.mark.asyncio
    async def test_provider_fallback_with_embedding_fix(self, mock_settings):
        """Test provider fallback properly handles Groq embedding NotImplementedError."""
        manager = AIProviderManager(mock_settings)

        # Mock Groq provider (doesn't support embeddings)
        mock_groq_provider = AsyncMock()
        mock_groq_provider.generate_embedding.side_effect = NotImplementedError(
            "Groq does not provide embedding APIs. Use OpenAI or another provider for embeddings."
        )

        # Mock OpenAI provider (supports embeddings)
        mock_openai_provider = AsyncMock()
        mock_openai_provider.generate_embedding.return_value = [0.1] * 1536

        manager.providers = {
            AIProviderType.GROQ: mock_groq_provider,
            AIProviderType.OPENAI: mock_openai_provider,
        }
        manager.active_providers = [AIProviderType.GROQ, AIProviderType.OPENAI]
        manager._initialized = True

        with patch(
            "core.ai_manager.get_preferred_providers",
            return_value=[AIProviderType.GROQ, AIProviderType.OPENAI],
        ):
            result = await manager.generate_embedding("Test text")

            # Should successfully fall back to OpenAI
            assert result == [0.1] * 1536
            assert len(result) == 1536

            # Verify Groq was tried first, then OpenAI succeeded
            mock_groq_provider.generate_embedding.assert_called_once_with("Test text")
            mock_openai_provider.generate_embedding.assert_called_once_with("Test text")

    @pytest.mark.asyncio
    async def test_latency_logging_with_model_dump_usage(self, mock_settings, caplog):
        """Test that latency logging works with .model_dump() usage metadata."""
        mock_provider_config = MagicMock()
        mock_provider_config.name = "OpenAI"
        mock_provider_config.rate_limit_rpm = 500

        openai_provider = OpenAIProvider(mock_provider_config, mock_settings)

        # Mock usage with model_dump method
        mock_usage = MagicMock()
        mock_usage.total_tokens = 200
        mock_usage.model_dump.return_value = {
            "prompt_tokens": 75,
            "completion_tokens": 125,
            "total_tokens": 200,
        }

        mock_message = MagicMock()
        mock_message.content = "Response with proper model_dump usage"

        mock_choice = MagicMock()
        mock_choice.message = mock_message

        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = mock_usage

        # Add delay to generate measurable latency
        async def mock_create(*args, **kwargs):
            await asyncio.sleep(0.01)
            return mock_response

        mock_client = AsyncMock()
        mock_client.chat.completions.create = mock_create
        openai_provider.client = mock_client
        openai_provider._initialized = True

        with (
            patch("core.ai_manager.get_model_config") as mock_get_model,
            caplog.at_level(logging.DEBUG),
        ):
            mock_model = MagicMock()
            mock_model.name = "gpt-4"
            mock_model.max_tokens = 1000
            mock_model.temperature = 0.7
            mock_model.top_p = 1.0
            mock_model.frequency_penalty = 0.0
            mock_model.presence_penalty = 0.0
            mock_model.cost_per_1k_tokens = 0.03
            mock_get_model.return_value = mock_model

            result = await openai_provider.generate_text(
                "Test prompt", TaskType.ANALYSIS
            )

            # Verify successful response with proper model_dump usage
            assert result.content == "Response with proper model_dump usage"
            assert result.success is True
            assert result.latency > 0

            # Verify model_dump was used correctly
            mock_usage.model_dump.assert_called_once()

            # Verify metadata contains usage from model_dump
            assert result.metadata is not None
            assert "usage" in result.metadata
            assert result.metadata["usage"]["total_tokens"] == 200

    @pytest.mark.asyncio
    async def test_comprehensive_fixes_work_together(self, mock_settings, caplog):
        """Test that all fixes work together in a comprehensive scenario."""
        manager = AIProviderManager(mock_settings)

        # Setup multiple providers with different capabilities

        # 1. Groq provider - supports transcription but NOT embeddings
        mock_groq_provider = AsyncMock()

        # Groq transcription with latency logging
        mock_transcription_response = MagicMock()
        mock_transcription_response.text = "Groq transcription"
        mock_transcription_response.language = "en"
        mock_transcription_response.duration = 8.0
        mock_transcription_response.segments = []

        async def mock_groq_transcribe(*args, **kwargs):
            await asyncio.sleep(0.01)  # Add latency
            return mock_transcription_response

        mock_groq_provider.transcribe_audio = mock_groq_transcribe

        # Groq embedding raises NotImplementedError
        mock_groq_provider.generate_embedding.side_effect = NotImplementedError(
            "Groq does not provide embedding APIs. Use OpenAI or another provider for embeddings."
        )

        # 2. OpenAI provider - supports both transcription and embeddings
        mock_openai_provider = AsyncMock()

        # OpenAI embedding with proper dimensions
        mock_openai_provider.generate_embedding.return_value = [0.1] * 1536

        # OpenAI text generation with .model_dump()
        mock_usage = MagicMock()
        mock_usage.total_tokens = 150
        mock_usage.model_dump.return_value = {"total_tokens": 150}

        mock_text_response = AIResponse(
            content="OpenAI response",
            provider="openai",
            model="gpt-4",
            tokens_used=150,
            cost=0.0045,
            latency=0.25,
            success=True,
            metadata={"usage": {"total_tokens": 150}},
        )

        mock_openai_provider.generate_text.return_value = mock_text_response

        # Setup manager
        manager.providers = {
            AIProviderType.GROQ: mock_groq_provider,
            AIProviderType.OPENAI: mock_openai_provider,
        }
        manager.active_providers = [AIProviderType.GROQ, AIProviderType.OPENAI]
        manager._initialized = True

        with (
            patch("core.ai_manager.get_preferred_providers") as mock_get_preferred,
            caplog.at_level(logging.DEBUG),
        ):
            # Test 1: Transcription works with latency logging
            mock_get_preferred.return_value = [AIProviderType.GROQ]
            transcription_result = await manager.transcribe(b"audio_data")

            assert transcription_result.text == "Groq transcription"

            # Verify latency was logged (implementation would need to be in the provider)
            # This tests the integration, actual latency logging is tested in individual tests

            # Test 2: Embedding falls back from Groq to OpenAI
            mock_get_preferred.return_value = [
                AIProviderType.GROQ,
                AIProviderType.OPENAI,
            ]
            embedding_result = await manager.generate_embedding("Test text")

            assert len(embedding_result) == 1536
            assert embedding_result == [0.1] * 1536

            # Verify Groq was tried first but failed with NotImplementedError
            mock_groq_provider.generate_embedding.assert_called_once()
            mock_openai_provider.generate_embedding.assert_called_once()

            # Test 3: Text generation uses .model_dump()
            mock_get_preferred.return_value = [AIProviderType.OPENAI]
            text_result = await manager.analyze_quote("Test quote analysis")

            assert text_result.content == "OpenAI response"
            assert text_result.success is True
            assert "usage" in text_result.metadata
            assert text_result.metadata["usage"]["total_tokens"] == 150

if __name__ == "__main__":
    pytest.main([__file__, "-v"])