196 lines
6.0 KiB
Python
196 lines
6.0 KiB
Python
"""Vectorizer utility for generating embeddings."""
|
|
|
|
from types import TracebackType
|
|
from typing import Final, Self, cast
|
|
|
|
import httpx
|
|
|
|
from typings import EmbeddingResponse
|
|
|
|
from ..core.exceptions import VectorizationError
|
|
from ..core.models import StorageConfig, VectorConfig
|
|
from ..config import get_settings
|
|
|
|
JSON_CONTENT_TYPE: Final[str] = "application/json"
|
|
AUTHORIZATION_HEADER: Final[str] = "Authorization"
|
|
|
|
|
|
class Vectorizer:
    """Handles text vectorization using LLM endpoints.

    Speaks the OpenAI-compatible ``/v1/embeddings`` protocol; works for both
    Ollama-hosted and OpenAI-style embedding models, since the request/response
    shapes are identical for the two backends.
    """

    # Base URL of the embeddings server (no trailing slash).
    endpoint: str
    # Model identifier sent in the request payload.
    model: str
    # Expected length of every returned embedding vector.
    dimension: int

    def __init__(self, config: StorageConfig | VectorConfig):
        """
        Initialize vectorizer.

        Args:
            config: Configuration with embedding details. A ``StorageConfig``
                defers to global settings for endpoint/model/dimension; a
                ``VectorConfig`` carries them explicitly.
        """
        settings = get_settings()
        if isinstance(config, StorageConfig):
            # Extract vector config from global settings when storage config is provided
            self.endpoint = str(settings.llm_endpoint).rstrip("/")
            self.model = settings.embedding_model
            self.dimension = settings.embedding_dimension
        else:
            self.endpoint = str(config.embedding_endpoint).rstrip("/")
            self.model = config.model
            self.dimension = config.dimension

        # Only attach an Authorization header when a key is actually configured,
        # so unauthenticated local endpoints (e.g. Ollama) keep working.
        resolved_api_key = settings.get_llm_api_key() or ""
        headers: dict[str, str] = {"Content-Type": JSON_CONTENT_TYPE}
        if resolved_api_key:
            headers[AUTHORIZATION_HEADER] = f"Bearer {resolved_api_key}"

        timeout_seconds = float(settings.request_timeout)
        self.client = httpx.AsyncClient(timeout=timeout_seconds, headers=headers)

    async def vectorize(self, text: str) -> list[float]:
        """
        Generate embedding vector for text.

        Args:
            text: Text to vectorize

        Returns:
            Embedding vector

        Raises:
            VectorizationError: If ``text`` is empty, the endpoint returns no
                or malformed embeddings, the dimension is wrong, or any
                transport/HTTP error occurs.
        """
        if not text:
            raise VectorizationError("Cannot vectorize empty text")

        try:
            # Both backends use the same OpenAI-compatible wire protocol; the
            # dispatch is kept so subclasses may still override either path.
            return (
                await self._ollama_embed(text)
                if "ollama" in self.model
                else await self._openai_embed(text)
            )
        except VectorizationError:
            # Already our domain error — don't double-wrap the message.
            raise
        except Exception as e:
            raise VectorizationError(f"Vectorization failed: {e}") from e

    async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Texts are vectorized sequentially, preserving input order.

        Args:
            texts: List of texts to vectorize

        Returns:
            List of embedding vectors
        """
        return [await self.vectorize(text) for text in texts]

    async def _ollama_embed(self, text: str) -> list[float]:
        """
        Generate embedding using Ollama via OpenAI-compatible endpoint.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        # Ollama exposes the same OpenAI-compatible API; delegate.
        return await self._embed(text)

    async def _openai_embed(self, text: str) -> list[float]:
        """
        Generate embedding using OpenAI-compatible API.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        return await self._embed(text)

    async def _embed(self, text: str) -> list[float]:
        """
        POST to the OpenAI-compatible ``/v1/embeddings`` endpoint and parse.

        Args:
            text: Text to embed

        Returns:
            Embedding vector of exactly ``self.dimension`` floats

        Raises:
            VectorizationError: If the response has no embeddings, a malformed
                embedding entry, or a vector of the wrong dimension.
            httpx.HTTPStatusError: On non-2xx responses (wrapped by callers).
        """
        response = await self.client.post(
            f"{self.endpoint}/v1/embeddings",
            json={
                "model": self.model,
                "input": text,
            },
        )
        _ = response.raise_for_status()

        # Response is expected to be dict[str, object] from our type stub
        response_data = cast(EmbeddingResponse, cast(object, response.json()))

        # Parse OpenAI-compatible response format: {"data": [{"embedding": [...]}]}
        embeddings_list = response_data.get("data", [])
        if not embeddings_list:
            raise VectorizationError("No embeddings returned")

        embedding_raw = embeddings_list[0].get("embedding")
        if not embedding_raw:
            raise VectorizationError("Invalid embedding format")

        # Convert to float list and validate the expected dimension.
        embedding = [float(item) for item in embedding_raw]
        if len(embedding) != self.dimension:
            raise VectorizationError(
                f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
            )

        return embedding

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Async context manager exit: close the underlying HTTP client."""
        await self.client.aclose()
|