# Commit message:
# - Deleted the .env.example file, as it is no longer needed.
# - Added .gitignore to manage ignored files and directories.
# - Introduced CLAUDE.md to document AI provider integration.
# - Created dev.sh for development setup and scripts.
# - Updated Dockerfile and Dockerfile.production to improve the build processes.
# - Added multiple test files and directories for comprehensive testing.
# - Introduced new utility and service files for enhanced functionality.
# - Organized the codebase with new directories and files for better maintainability.
#
# File info: 389 lines, 13 KiB, Python
"""
Unit tests for the AI manager.

Tests for AI provider management, text generation, embeddings, and fallback mechanisms.
"""

import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from core.ai_manager import AIProviderManager
from tests.conftest import TestConfig


class TestAIProviderManager:
    """Unit tests for AIProviderManager class.

    Provider ``generate_text`` / ``generate_embedding`` are coroutines (the
    provider tests in this module ``await`` them), so they are mocked with
    ``AsyncMock``. The return value of a plain ``MagicMock`` call cannot be
    awaited and would raise ``TypeError`` inside the manager.
    """

    # NOTE(review): async-generator fixture declared with plain
    # ``@pytest.fixture``. This only works when pytest-asyncio runs in "auto"
    # mode. In "strict" mode it must be ``@pytest_asyncio.fixture``, otherwise
    # tests receive an un-run async generator. Confirm the project's
    # asyncio_mode.
    @pytest.fixture
    async def ai_manager(self):
        """Create an initialized AI manager and close it after the test."""
        ai_manager = AIProviderManager()
        await ai_manager.initialize()
        yield ai_manager
        # pytest resumes the fixture after the test finishes (pass or fail),
        # so the manager is closed in both cases.
        await ai_manager.close()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_initialization(self, ai_manager):
        """Test AI manager initialization"""
        assert ai_manager._initialized
        assert len(ai_manager.providers) > 0
        assert ai_manager.current_provider is not None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_provider_registration(self, ai_manager):
        """Test AI provider registration"""
        # Mock provider
        mock_provider = MagicMock()
        mock_provider.name = "test_provider"
        mock_provider.priority = 1

        ai_manager.register_provider(mock_provider)

        assert "test_provider" in ai_manager.providers
        assert ai_manager.providers["test_provider"] == mock_provider

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_text_success(self, ai_manager):
        """Test successful text generation"""
        with patch.object(ai_manager, "current_provider") as mock_provider:
            # Coroutine method -> AsyncMock, so the manager can await it.
            mock_provider.generate_text = AsyncMock(
                return_value=TestConfig.MOCK_AI_RESPONSE
            )

            result = await ai_manager.generate_text(
                prompt="Test prompt", max_tokens=100
            )

            assert result == TestConfig.MOCK_AI_RESPONSE
            mock_provider.generate_text.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_text_with_fallback(self, ai_manager):
        """Test text generation with provider fallback"""
        # Mock primary provider failure (raised when the call is awaited)
        mock_primary = MagicMock()
        mock_primary.name = "primary"
        mock_primary.priority = 1
        mock_primary.generate_text = AsyncMock(side_effect=Exception("API Error"))

        # Mock fallback provider success
        mock_fallback = MagicMock()
        mock_fallback.name = "fallback"
        mock_fallback.priority = 2
        mock_fallback.generate_text = AsyncMock(
            return_value=TestConfig.MOCK_AI_RESPONSE
        )

        ai_manager.providers = {"primary": mock_primary, "fallback": mock_fallback}
        ai_manager.current_provider = mock_primary

        result = await ai_manager.generate_text("Test prompt")

        # Should use fallback provider
        assert result == TestConfig.MOCK_AI_RESPONSE
        mock_primary.generate_text.assert_awaited_once()
        mock_fallback.generate_text.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_success(self, ai_manager):
        """Test successful embedding generation"""
        mock_embedding = [0.1, 0.2, 0.3] * 128  # 384-dim embedding

        with patch.object(ai_manager, "current_provider") as mock_provider:
            mock_provider.generate_embedding = AsyncMock(return_value=mock_embedding)

            result = await ai_manager.generate_embedding("Test text")

            assert result == mock_embedding
            assert len(result) == 384
            mock_provider.generate_embedding.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_provider_health_check(self, ai_manager):
        """Test provider health checking"""
        # NOTE(review): check_health is mocked as a synchronous method. If the
        # provider API makes it a coroutine, switch to AsyncMock; confirm
        # against the provider base class.
        with patch.object(ai_manager, "current_provider") as mock_provider:
            mock_provider.check_health.return_value = {
                "healthy": True,
                "response_time": 0.5,
                "rate_limit_remaining": 100,
            }

            health = await ai_manager.check_provider_health()

            assert health["healthy"] is True
            assert "response_time" in health
            mock_provider.check_health.assert_called_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_rate_limit_handling(self, ai_manager):
        """Test rate limit handling"""
        mock_provider = MagicMock()
        mock_provider.name = "rate_limited"
        mock_provider.is_rate_limited.return_value = True
        mock_provider.get_rate_limit_reset_time.return_value = datetime.utcnow()

        ai_manager.current_provider = mock_provider

        # Should handle rate limiting gracefully.
        # NOTE(review): pytest.raises(Exception) is broad; tighten it to the
        # manager's rate-limit exception type once that is known.
        with pytest.raises(Exception) as exc_info:
            await ai_manager.generate_text("Test prompt")

        assert "rate limit" in str(exc_info.value).lower()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_provider_switching(self, ai_manager):
        """Test automatic provider switching"""
        # Create multiple mock providers
        provider1 = MagicMock()
        provider1.name = "provider1"
        provider1.priority = 1
        provider1.is_healthy = False

        provider2 = MagicMock()
        provider2.name = "provider2"
        provider2.priority = 2
        provider2.is_healthy = True

        ai_manager.providers = {"provider1": provider1, "provider2": provider2}

        # Should switch to healthy provider
        await ai_manager._switch_to_best_provider()

        assert ai_manager.current_provider == provider2

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_concurrent_requests(self, ai_manager):
        """Test handling concurrent AI requests"""
        with patch.object(ai_manager, "current_provider") as mock_provider:
            mock_provider.generate_text = AsyncMock(
                return_value=TestConfig.MOCK_AI_RESPONSE
            )

            # Submit multiple concurrent requests
            tasks = [ai_manager.generate_text(f"Prompt {i}") for i in range(5)]
            results = await asyncio.gather(*tasks)

            assert len(results) == 5
            for result in results:
                assert result == TestConfig.MOCK_AI_RESPONSE

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_provider_metrics(self, ai_manager):
        """Test provider performance metrics"""
        # NOTE(review): get_metrics is mocked as a synchronous method. Switch
        # to AsyncMock if the provider API defines it as a coroutine.
        with patch.object(ai_manager, "current_provider") as mock_provider:
            mock_provider.get_metrics.return_value = {
                "total_requests": 100,
                "successful_requests": 95,
                "failed_requests": 5,
                "average_response_time": 1.2,
                "tokens_used": 5000,
            }

            metrics = await ai_manager.get_provider_metrics()

            assert metrics["total_requests"] == 100
            assert metrics["successful_requests"] == 95
            assert "average_response_time" in metrics


class TestOpenAIProvider:
    """Unit tests for OpenAI provider.

    The async OpenAI client's ``create`` methods are coroutine functions, so
    they are patched with ``new_callable=AsyncMock``. Without it, ``patch.object``
    on an attribute of the patched (MagicMock) client yields a plain
    ``MagicMock``, whose return value cannot be awaited.
    """

    @pytest.fixture
    def openai_provider(self):
        """Create an OpenAI provider while ``openai.AsyncOpenAI`` is patched."""
        # NOTE(review): the patch only reaches the provider if it resolves
        # ``openai.AsyncOpenAI`` while this context is active (e.g. the module
        # is first imported here). Confirm how the provider builds its client.
        with patch("openai.AsyncOpenAI"):
            from core.ai_providers.openai_provider import OpenAIProvider

            provider = OpenAIProvider(api_key="test_key")

        return provider

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_text_generation(self, openai_provider):
        """Test OpenAI text generation"""
        with patch.object(
            openai_provider.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="Test response"))]
            )

            result = await openai_provider.generate_text("Test prompt")

            assert result.choices[0].message.content == "Test response"
            mock_create.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_embedding_generation(self, openai_provider):
        """Test OpenAI embedding generation"""
        mock_embedding = [0.1] * 1536  # OpenAI embedding dimension

        with patch.object(
            openai_provider.client.embeddings, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = MagicMock(
                data=[MagicMock(embedding=mock_embedding)]
            )

            result = await openai_provider.generate_embedding("Test text")

            assert result == mock_embedding
            mock_create.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_error_handling(self, openai_provider):
        """Test OpenAI error handling"""
        with patch.object(
            openai_provider.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
            # AsyncMock raises side_effect exceptions when awaited.
            mock_create.side_effect = Exception("OpenAI API Error")

            with pytest.raises(Exception) as exc_info:
                await openai_provider.generate_text("Test prompt")

            assert "OpenAI API Error" in str(exc_info.value)


class TestAnthropicProvider:
    """Unit tests for Anthropic provider.

    ``messages.create`` on the async Anthropic client is a coroutine function,
    so it is patched with ``new_callable=AsyncMock``.
    """

    @pytest.fixture
    def anthropic_provider(self):
        """Create an Anthropic provider while ``anthropic.AsyncAnthropic`` is patched."""
        with patch("anthropic.AsyncAnthropic"):
            from core.ai_providers.anthropic_provider import AnthropicProvider

            provider = AnthropicProvider(api_key="test_key")

        return provider

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_text_generation(self, anthropic_provider):
        """Test Anthropic text generation"""
        with patch.object(
            anthropic_provider.client.messages, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.return_value = MagicMock(
                content=[MagicMock(text="Test response")]
            )

            result = await anthropic_provider.generate_text("Test prompt")

            assert result.content[0].text == "Test response"
            mock_create.assert_awaited_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_rate_limit_handling(self, anthropic_provider):
        """Test Anthropic rate limit handling"""
        import anthropic

        # RateLimitError is an APIStatusError, whose constructor requires the
        # keyword-only ``response`` and ``body`` arguments. Passing only a
        # message raises TypeError before the test starts. A mock response
        # carrying status 429 is enough to build it.
        rate_limit_error = anthropic.RateLimitError(
            "Rate limit exceeded",
            response=MagicMock(status_code=429),
            body=None,
        )

        with patch.object(
            anthropic_provider.client.messages, "create", new_callable=AsyncMock
        ) as mock_create:
            mock_create.side_effect = rate_limit_error

            with pytest.raises(anthropic.RateLimitError):
                await anthropic_provider.generate_text("Test prompt")


class TestLocalProvider:
    """Unit tests covering the local (offline) AI provider."""

    @pytest.fixture
    def local_provider(self):
        """Build a LocalProvider pointed at a dummy model path."""
        # Imported inside the fixture, so the import only happens when a test
        # actually requests this provider.
        from core.ai_providers.local_provider import LocalProvider

        return LocalProvider(model_path="test_model")

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_model_loading(self, local_provider):
        """initialize() must populate ``model`` through a single _load_model call."""
        with patch.object(
            local_provider, "_load_model", return_value=MagicMock()
        ) as fake_loader:
            await local_provider.initialize()

            assert local_provider.model is not None
            fake_loader.assert_called_once()

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_offline_generation(self, local_provider):
        """generate_text should surface the output of the underlying model."""
        with patch.object(local_provider, "model") as fake_model:
            fake_model.generate.return_value = ["Test response"]

            output = await local_provider.generate_text("Test prompt")

            assert "Test response" in str(output)
            fake_model.generate.assert_called_once()


class TestAIProviderSelection:
    """Tests for AI provider selection and fallback logic.

    NOTE(review): both tests exercise selection/fallback logic written inline
    in the test itself, not production code. Consider pointing them at the
    manager's real selection and fallback methods.
    """

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_provider_priority_selection(self):
        """Test provider selection based on priority"""
        provider1 = MagicMock()
        provider1.priority = 3
        provider1.is_healthy = True

        provider2 = MagicMock()
        provider2.priority = 1  # Higher priority (lower number)
        provider2.is_healthy = True

        providers = [provider1, provider2]

        # Should select provider with highest priority (lowest number);
        # unhealthy providers sort last via an infinite key.
        selected = min(
            providers, key=lambda p: p.priority if p.is_healthy else float("inf")
        )

        assert selected == provider2

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_fallback_chain(self):
        """Test complete fallback chain execution"""
        providers = []

        # Create providers with different failure modes. generate_text is an
        # AsyncMock because the loop below awaits it. With a plain MagicMock,
        # the successful provider's return value is not awaitable: the
        # resulting TypeError is swallowed by ``except Exception`` and
        # ``result`` stays None.
        for i in range(3):
            provider = MagicMock()
            provider.name = f"provider_{i}"
            provider.priority = i + 1
            if i < 2:  # First two providers fail
                provider.generate_text = AsyncMock(side_effect=Exception(f"Error {i}"))
            else:  # Last provider succeeds
                provider.generate_text = AsyncMock(
                    return_value=TestConfig.MOCK_AI_RESPONSE
                )
            providers.append(provider)

        # Test fallback logic: try providers in priority order, stop at the
        # first success.
        result = None
        for provider in sorted(providers, key=lambda p: p.priority):
            try:
                result = await provider.generate_text("Test prompt")
                break
            except Exception:
                continue

        assert result == TestConfig.MOCK_AI_RESPONSE
        # Every provider in the chain was tried exactly once, in order,
        # up to and including the successful one.
        for provider in providers:
            provider.generate_text.assert_awaited_once_with("Test prompt")


if __name__ == "__main__":
    # pytest.main returns the session exit code. Raise it through SystemExit
    # so that running this file directly reports test failures to the shell
    # or CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))