disbord/tests/test_ai_manager.py
"""
Unit Tests for AI Manager
Tests for AI provider management, text generation, embeddings, and fallback mechanisms.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio

from core.ai_manager import AIProviderManager
from tests.conftest import TestConfig
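
# The tests below mock the provider layer rather than real SDK clients. The
# provider interface assumed throughout is roughly this (hypothetical sketch,
# for orientation only; the real classes live under core.ai_providers):
#
#     class BaseProvider:
#         name: str
#         priority: int            # lower number = higher priority
#         is_healthy: bool
#         async def generate_text(self, prompt, **kwargs): ...
#         async def generate_embedding(self, text): ...
#         async def check_health(self): ...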
class TestAIProviderManager:
"""Unit tests for AIProviderManager class"""
    # pytest-asyncio's own fixture decorator is required for async fixtures
    # when asyncio_mode is "strict" (the explicit asyncio marks suggest it is)
    @pytest_asyncio.fixture
async def ai_manager(self):
"""Create AI manager instance for testing"""
ai_manager = AIProviderManager()
await ai_manager.initialize()
yield ai_manager
await ai_manager.close()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_initialization(self, ai_manager):
"""Test AI manager initialization"""
assert ai_manager._initialized
assert len(ai_manager.providers) > 0
assert ai_manager.current_provider is not None
@pytest.mark.unit
@pytest.mark.asyncio
async def test_provider_registration(self, ai_manager):
"""Test AI provider registration"""
# Mock provider
mock_provider = MagicMock()
mock_provider.name = "test_provider"
mock_provider.priority = 1
ai_manager.register_provider(mock_provider)
assert "test_provider" in ai_manager.providers
assert ai_manager.providers["test_provider"] == mock_provider
@pytest.mark.unit
@pytest.mark.asyncio
async def test_generate_text_success(self, ai_manager):
"""Test successful text generation"""
        # AsyncMock (rather than MagicMock) keeps the provider's coroutine
        # methods awaitable when the manager awaits them
        with patch.object(
            ai_manager, "current_provider", new_callable=AsyncMock
        ) as mock_provider:
            mock_provider.generate_text.return_value = TestConfig.MOCK_AI_RESPONSE
result = await ai_manager.generate_text(
prompt="Test prompt", max_tokens=100
)
assert result == TestConfig.MOCK_AI_RESPONSE
mock_provider.generate_text.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_generate_text_with_fallback(self, ai_manager):
"""Test text generation with provider fallback"""
# Mock primary provider failure
mock_primary = MagicMock()
mock_primary.name = "primary"
mock_primary.priority = 1
        mock_primary.generate_text = AsyncMock(side_effect=Exception("API Error"))
# Mock fallback provider success
mock_fallback = MagicMock()
mock_fallback.name = "fallback"
mock_fallback.priority = 2
        mock_fallback.generate_text = AsyncMock(return_value=TestConfig.MOCK_AI_RESPONSE)
ai_manager.providers = {"primary": mock_primary, "fallback": mock_fallback}
ai_manager.current_provider = mock_primary
result = await ai_manager.generate_text("Test prompt")
# Should use fallback provider
assert result == TestConfig.MOCK_AI_RESPONSE
mock_primary.generate_text.assert_called_once()
mock_fallback.generate_text.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_generate_embedding_success(self, ai_manager):
"""Test successful embedding generation"""
mock_embedding = [0.1, 0.2, 0.3] * 128 # 384-dim embedding
        with patch.object(
            ai_manager, "current_provider", new_callable=AsyncMock
        ) as mock_provider:
            mock_provider.generate_embedding.return_value = mock_embedding
result = await ai_manager.generate_embedding("Test text")
assert result == mock_embedding
assert len(result) == 384
mock_provider.generate_embedding.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_provider_health_check(self, ai_manager):
"""Test provider health checking"""
        with patch.object(
            ai_manager, "current_provider", new_callable=AsyncMock
        ) as mock_provider:
mock_provider.check_health.return_value = {
"healthy": True,
"response_time": 0.5,
"rate_limit_remaining": 100,
}
health = await ai_manager.check_provider_health()
assert health["healthy"] is True
assert "response_time" in health
mock_provider.check_health.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_handling(self, ai_manager):
"""Test rate limit handling"""
mock_provider = MagicMock()
mock_provider.name = "rate_limited"
mock_provider.is_rate_limited.return_value = True
        mock_provider.get_rate_limit_reset_time.return_value = datetime.now(timezone.utc)
ai_manager.current_provider = mock_provider
# Should handle rate limiting gracefully
with pytest.raises(Exception) as exc_info:
await ai_manager.generate_text("Test prompt")
assert "rate limit" in str(exc_info.value).lower()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_provider_switching(self, ai_manager):
"""Test automatic provider switching"""
# Create multiple mock providers
provider1 = MagicMock()
provider1.name = "provider1"
provider1.priority = 1
provider1.is_healthy = False
provider2 = MagicMock()
provider2.name = "provider2"
provider2.priority = 2
provider2.is_healthy = True
ai_manager.providers = {"provider1": provider1, "provider2": provider2}
# Should switch to healthy provider
await ai_manager._switch_to_best_provider()
assert ai_manager.current_provider == provider2
@pytest.mark.unit
@pytest.mark.asyncio
async def test_concurrent_requests(self, ai_manager):
"""Test handling concurrent AI requests"""
        with patch.object(
            ai_manager, "current_provider", new_callable=AsyncMock
        ) as mock_provider:
mock_provider.generate_text.return_value = TestConfig.MOCK_AI_RESPONSE
            # Submit multiple concurrent requests
            tasks = [ai_manager.generate_text(f"Prompt {i}") for i in range(5)]
results = await asyncio.gather(*tasks)
assert len(results) == 5
for result in results:
assert result == TestConfig.MOCK_AI_RESPONSE
@pytest.mark.unit
@pytest.mark.asyncio
async def test_provider_metrics(self, ai_manager):
"""Test provider performance metrics"""
        with patch.object(
            ai_manager, "current_provider", new_callable=AsyncMock
        ) as mock_provider:
mock_provider.get_metrics.return_value = {
"total_requests": 100,
"successful_requests": 95,
"failed_requests": 5,
"average_response_time": 1.2,
"tokens_used": 5000,
}
metrics = await ai_manager.get_provider_metrics()
assert metrics["total_requests"] == 100
assert metrics["successful_requests"] == 95
assert "average_response_time" in metrics
class TestOpenAIProvider:
"""Unit tests for OpenAI provider"""
@pytest.fixture
def openai_provider(self):
"""Create OpenAI provider for testing"""
with patch("openai.AsyncOpenAI"):
from core.ai_providers.openai_provider import OpenAIProvider
provider = OpenAIProvider(api_key="test_key")
return provider
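
    # Because openai.AsyncOpenAI is patched in the fixture above, provider.client
    # is a MagicMock and no real network calls occur; each test re-patches the
    # relevant "create" method as an AsyncMock so it can be awaited.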
@pytest.mark.unit
@pytest.mark.asyncio
async def test_text_generation(self, openai_provider):
"""Test OpenAI text generation"""
        with patch.object(
            openai_provider.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="Test response"))]
)
result = await openai_provider.generate_text("Test prompt")
assert result.choices[0].message.content == "Test response"
mock_create.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_embedding_generation(self, openai_provider):
"""Test OpenAI embedding generation"""
mock_embedding = [0.1] * 1536 # OpenAI embedding dimension
        with patch.object(
            openai_provider.client.embeddings, "create", new_callable=AsyncMock
        ) as mock_create:
mock_create.return_value = MagicMock(
data=[MagicMock(embedding=mock_embedding)]
)
result = await openai_provider.generate_embedding("Test text")
assert result == mock_embedding
mock_create.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_error_handling(self, openai_provider):
"""Test OpenAI error handling"""
        with patch.object(
            openai_provider.client.chat.completions, "create", new_callable=AsyncMock
        ) as mock_create:
mock_create.side_effect = Exception("OpenAI API Error")
with pytest.raises(Exception) as exc_info:
await openai_provider.generate_text("Test prompt")
assert "OpenAI API Error" in str(exc_info.value)
class TestAnthropicProvider:
"""Unit tests for Anthropic provider"""
@pytest.fixture
def anthropic_provider(self):
"""Create Anthropic provider for testing"""
with patch("anthropic.AsyncAnthropic"):
from core.ai_providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(api_key="test_key")
return provider
@pytest.mark.unit
@pytest.mark.asyncio
async def test_text_generation(self, anthropic_provider):
"""Test Anthropic text generation"""
        with patch.object(
            anthropic_provider.client.messages, "create", new_callable=AsyncMock
        ) as mock_create:
mock_create.return_value = MagicMock(
content=[MagicMock(text="Test response")]
)
result = await anthropic_provider.generate_text("Test prompt")
assert result.content[0].text == "Test response"
mock_create.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_rate_limit_handling(self, anthropic_provider):
"""Test Anthropic rate limit handling"""
import anthropic
        with patch.object(
            anthropic_provider.client.messages, "create", new_callable=AsyncMock
        ) as mock_create:
            # Recent anthropic SDK versions require the response and body kwargs
            mock_create.side_effect = anthropic.RateLimitError(
                "Rate limit exceeded", response=MagicMock(status_code=429), body=None
            )
with pytest.raises(anthropic.RateLimitError):
await anthropic_provider.generate_text("Test prompt")
class TestLocalProvider:
"""Unit tests for local AI provider"""
@pytest.fixture
def local_provider(self):
"""Create local provider for testing"""
from core.ai_providers.local_provider import LocalProvider
provider = LocalProvider(model_path="test_model")
return provider
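
    # No patching at construction time: LocalProvider is assumed to defer model
    # loading until initialize(), which test_model_loading patches directly.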
@pytest.mark.unit
@pytest.mark.asyncio
async def test_model_loading(self, local_provider):
"""Test local model loading"""
with patch.object(local_provider, "_load_model") as mock_load:
mock_load.return_value = MagicMock()
await local_provider.initialize()
assert local_provider.model is not None
mock_load.assert_called_once()
@pytest.mark.unit
@pytest.mark.asyncio
async def test_offline_generation(self, local_provider):
"""Test offline text generation"""
        # create=True lets the patch apply even if "model" is unset before initialize()
        with patch.object(local_provider, "model", create=True) as mock_model:
mock_model.generate.return_value = ["Test response"]
result = await local_provider.generate_text("Test prompt")
assert "Test response" in str(result)
mock_model.generate.assert_called_once()
class TestAIProviderSelection:
"""Tests for AI provider selection and fallback logic"""
@pytest.mark.unit
@pytest.mark.asyncio
async def test_provider_priority_selection(self):
"""Test provider selection based on priority"""
provider1 = MagicMock()
provider1.priority = 3
provider1.is_healthy = True
provider2 = MagicMock()
provider2.priority = 1 # Higher priority (lower number)
provider2.is_healthy = True
providers = [provider1, provider2]
# Should select provider with highest priority (lowest number)
selected = min(
providers, key=lambda p: p.priority if p.is_healthy else float("inf")
)
assert selected == provider2
@pytest.mark.unit
@pytest.mark.asyncio
async def test_fallback_chain(self):
"""Test complete fallback chain execution"""
providers = []
# Create providers with different failure modes
        for i in range(3):
            provider = MagicMock()
            provider.name = f"provider_{i}"
            provider.priority = i + 1
            if i < 2:  # First two providers fail
                provider.generate_text = AsyncMock(side_effect=Exception(f"Error {i}"))
            else:  # Last provider succeeds; AsyncMock keeps it awaitable below
                provider.generate_text = AsyncMock(
                    return_value=TestConfig.MOCK_AI_RESPONSE
                )
            providers.append(provider)
# Test fallback logic
result = None
for provider in sorted(providers, key=lambda p: p.priority):
try:
result = await provider.generate_text("Test prompt")
break
except Exception:
continue
assert result == TestConfig.MOCK_AI_RESPONSE
assert providers[2].generate_text.called
if __name__ == "__main__":
pytest.main([__file__, "-v"])
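
# To run just this module (assuming pytest-asyncio is installed and configured):
#   pytest tests/test_ai_manager.py -v -m unit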