"""Vectorizer utility for generating embeddings."""
|
|
|
|
from types import TracebackType
|
|
from typing import Self, cast
|
|
|
|
import httpx
|
|
|
|
from typings import EmbeddingResponse
|
|
|
|
from ..core.exceptions import VectorizationError
|
|
from ..core.models import StorageConfig, VectorConfig
|
|
|
|
|
|
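
# For context, `typings.EmbeddingResponse` is assumed to be a TypedDict along
# these lines (sketch only; the real definition lives in the local `typings`
# module and may differ):
#
#     class EmbeddingItem(TypedDict):
#         embedding: list[float]
#
#     class EmbeddingResponse(TypedDict):
#         data: list[EmbeddingItem]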


class Vectorizer:
    """Handles text vectorization using LLM endpoints."""

    endpoint: str
    model: str
    dimension: int

    def __init__(self, config: StorageConfig | VectorConfig):
        """
        Initialize vectorizer.

        Args:
            config: Configuration with embedding details
        """
        if isinstance(config, StorageConfig):
            # StorageConfig carries no embedding settings, so fall back to the
            # default local endpoint and model
            self.endpoint = "http://llm.lab"
            self.model = "ollama/bge-m3"
            self.dimension = 1024
        else:
            self.endpoint = str(config.embedding_endpoint)
            self.model = config.model
            self.dimension = config.dimension

        # Get API key from environment
        import os
        from pathlib import Path

        from dotenv import load_dotenv

        # Load .env from the project root
        env_path = Path(__file__).parent.parent.parent / ".env"
        _ = load_dotenv(env_path)

        api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY") or ""

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        self.client = httpx.AsyncClient(timeout=60.0, headers=headers)  # type: ignore[attr-defined]

    async def vectorize(self, text: str) -> list[float]:
        """
        Generate embedding vector for text.

        Args:
            text: Text to vectorize

        Returns:
            Embedding vector
        """
        if not text:
            raise VectorizationError("Cannot vectorize empty text")

        try:
            return (
                await self._ollama_embed(text)
                if "ollama" in self.model
                else await self._openai_embed(text)
            )
        except Exception as e:
            raise VectorizationError(f"Vectorization failed: {e}") from e

    async def vectorize_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for multiple texts.

        Args:
            texts: List of texts to vectorize

        Returns:
            List of embedding vectors
        """
        vectors: list[list[float]] = []

        # Embed sequentially; a failure on any single text aborts the batch.
        for text in texts:
            vector = await self.vectorize(text)
            vectors.append(vector)

        return vectors

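    # A concurrent variant of the batch path is possible, e.g. (sketch,
    # assuming the endpoint tolerates parallel requests):
    #
    #     vectors = await asyncio.gather(*(self.vectorize(t) for t in texts))
    #
    # The sequential loop above keeps error handling simple and the load on a
    # single local endpoint predictable.
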
    async def _ollama_embed(self, text: str) -> list[float]:
        """
        Generate embedding using Ollama via OpenAI-compatible endpoint.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        # Use the full model name as it appears in the API
        model_name = self.model

        # Use the OpenAI-compatible endpoint for Ollama models
        response = await self.client.post(
            f"{self.endpoint}/v1/embeddings",
            json={
                "model": model_name,
                "input": text,
            },
        )
        _ = response.raise_for_status()

        response_json = response.json()
        # Response is expected to be dict[str, object] from our type stub
        response_data = cast(EmbeddingResponse, cast(object, response_json))

        # Parse OpenAI-compatible response format
        embeddings_list = response_data.get("data", [])
        if not embeddings_list:
            raise VectorizationError("No embeddings returned")

        first_embedding = embeddings_list[0]
        embedding_raw = first_embedding.get("embedding")
        if not embedding_raw:
            raise VectorizationError("Invalid embedding format")

        # Convert to float list and validate
        embedding: list[float] = [float(item) for item in embedding_raw]

        # Ensure correct dimension
        if len(embedding) != self.dimension:
            raise VectorizationError(
                f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
            )

        return embedding

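    # For reference, both embed helpers expect an OpenAI-compatible response
    # body, roughly:
    #
    #     {
    #         "object": "list",
    #         "data": [{"object": "embedding", "index": 0, "embedding": [0.1, ...]}],
    #         "model": "ollama/bge-m3",
    #     }
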
    async def _openai_embed(self, text: str) -> list[float]:
        """
        Generate embedding using OpenAI-compatible API.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        response = await self.client.post(
            f"{self.endpoint}/v1/embeddings",
            json={
                "model": self.model,
                "input": text,
            },
        )
        _ = response.raise_for_status()

        response_json = response.json()
        # Response is expected to be dict[str, object] from our type stub
        response_data = cast(EmbeddingResponse, cast(object, response_json))

        embeddings_list = response_data.get("data", [])
        if not embeddings_list:
            raise VectorizationError("No embeddings returned")

        first_embedding = embeddings_list[0]
        embedding_raw = first_embedding.get("embedding")
        if not embedding_raw:
            raise VectorizationError("Invalid embedding format")

        # Convert to float list and validate
        embedding: list[float] = [float(item) for item in embedding_raw]

        # Ensure correct dimension
        if len(embedding) != self.dimension:
            raise VectorizationError(
                f"Embedding dimension mismatch: expected {self.dimension}, received {len(embedding)}"
            )

        return embedding

    async def __aenter__(self) -> Self:
        """Async context manager entry."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Async context manager exit."""
        await self.client.aclose()
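

# Illustrative usage (sketch, not part of the module): assumes VectorConfig
# accepts these fields as keyword arguments and that an OpenAI-compatible
# embedding endpoint is reachable at the configured URL.
#
#     import asyncio
#
#     async def main() -> None:
#         config = VectorConfig(
#             embedding_endpoint="http://llm.lab",
#             model="ollama/bge-m3",
#             dimension=1024,
#         )
#         async with Vectorizer(config) as vectorizer:
#             vector = await vectorizer.vectorize("hello world")
#             assert len(vector) == 1024
#
#     asyncio.run(main())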