rag-manager/ingest_pipeline/utils/vectorizer.py

"""Vectorizer utility for generating embeddings."""
import os
from pathlib import Path
from types import TracebackType
from typing import Self, cast

import httpx
from dotenv import load_dotenv

from typings import EmbeddingResponse

from ..core.exceptions import VectorizationError
from ..core.models import StorageConfig, VectorConfig


class Vectorizer:
    """Handles text vectorization using LLM endpoints."""

    endpoint: str
    model: str
    dimension: int
    def __init__(self, config: StorageConfig | VectorConfig):
        """
        Initialize vectorizer.

        Args:
            config: Configuration with embedding details
        """
        if isinstance(config, StorageConfig):
            # Fall back to the built-in defaults when given a StorageConfig
            self.endpoint = "http://llm.lab"
            self.model = "ollama/bge-m3"
            self.dimension = 1024
        else:
            self.endpoint = str(config.embedding_endpoint)
            self.model = config.model
            self.dimension = config.dimension

        # Load .env from the project root, then read the API key from the
        # environment
        env_path = Path(__file__).parent.parent.parent / ".env"
        _ = load_dotenv(env_path)
        api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or ""

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        self.client = httpx.AsyncClient(timeout=60.0, headers=headers)  # type: ignore[attr-defined]
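
    # Example .env consumed above (illustrative placeholder values; as the
    # lookup order shows, LLM_API_KEY takes precedence over OPENAI_API_KEY):
    #
    #     LLM_API_KEY=...
    #     OPENAI_API_KEY=...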

    async def vectorize(self, text: str) -> list[float]:
        """
        Generate embedding vector for text.

        Args:
            text: Text to vectorize

        Returns:
            Embedding vector
        """
        if not text:
            raise VectorizationError("Cannot vectorize empty text")
        try:
            return (
                await self._ollama_embed(text)
                if "ollama" in self.model
                else await self._openai_embed(text)
            )
        except VectorizationError:
            # Already descriptive; don't wrap it a second time
            raise
        except Exception as e:
            raise VectorizationError(f"Vectorization failed: {e}") from e

    async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to vectorize

        Returns:
            List of embedding vectors
        """
        # Requests are issued sequentially, one vectorize() call per text
        vectors: list[list[float]] = []
        for text in texts:
            vector = await self.vectorize(text)
            vectors.append(vector)
        return vectors
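
    # Illustrative alternative (a sketch, not part of the original module):
    # the sequential loop above keeps pressure on the embedding endpoint low;
    # a concurrent variant could fan the calls out with asyncio.gather:
    #
    #     import asyncio
    #
    #     vectors = await asyncio.gather(*(self.vectorize(t) for t in texts))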

    async def _ollama_embed(self, text: str) -> list[float]:
        """
        Generate embedding using Ollama via the OpenAI-compatible endpoint.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        # Ollama serves the same OpenAI-compatible /v1/embeddings route with
        # the full model name, so the request and response handling are
        # identical to _openai_embed
        return await self._openai_embed(text)
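
    # Illustrative sketch (an assumption, not used by this module): Ollama
    # also exposes a native embeddings route alongside the OpenAI-compatible
    # one. Assuming self.endpoint hosts an Ollama server and the bare model
    # tag (e.g. "bge-m3") is used, a direct call could look like:
    #
    #     response = await self.client.post(
    #         f"{self.endpoint}/api/embeddings",
    #         json={"model": "bge-m3", "prompt": text},
    #     )
    #     embedding = response.json()["embedding"]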

    async def _openai_embed(self, text: str) -> list[float]:
        """
        Generate embedding using OpenAI-compatible API.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        response = await self.client.post(
            f"{self.endpoint}/v1/embeddings",
            json={
                "model": self.model,
                "input": text,
            },
        )
        _ = response.raise_for_status()
        response_json = response.json()
        # Response is expected to be dict[str, object] from our type stub
        response_data = cast(EmbeddingResponse, cast(object, response_json))
        # Parse the OpenAI-compatible response format
        embeddings_list = response_data.get("data", [])
        if not embeddings_list:
            raise VectorizationError("No embeddings returned")
        first_embedding = embeddings_list[0]
        embedding_raw = first_embedding.get("embedding")
        if not embedding_raw:
            raise VectorizationError("Invalid embedding format")
        # Convert to a float list and validate the dimension
        embedding: list[float] = [float(item) for item in embedding_raw]
        if len(embedding) != self.dimension:
            raise VectorizationError(
                f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
            )
        return embedding
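
    # For reference, the OpenAI-compatible /v1/embeddings response parsed
    # above looks like this (abridged):
    #
    #     {
    #         "object": "list",
    #         "data": [
    #             {"object": "embedding", "index": 0, "embedding": [0.0123, ...]}
    #         ],
    #         "model": "...",
    #         "usage": {"prompt_tokens": 8, "total_tokens": 8}
    #     }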

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Async context manager exit."""
        await self.client.aclose()
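

# Illustrative usage sketch (not part of the original module). The keyword
# construction of VectorConfig is an assumption; the field names match the
# attributes read in __init__ above:
#
#     import asyncio
#
#     async def main() -> None:
#         config = VectorConfig(
#             embedding_endpoint="http://llm.lab",
#             model="ollama/bge-m3",
#             dimension=1024,
#         )
#         async with Vectorizer(config) as vectorizer:
#             vector = await vectorizer.vectorize("hello world")
#             print(len(vector))  # expected: 1024
#
#     asyncio.run(main())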