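"""Persistent Chroma vector store for the DigitalDan app (utils/vector_db.py).

Wraps a persistent Chroma collection (default name ``me_profile``) and embeds
documents with OpenAI's ``text-embedding-3-large`` model via langchain-openai.
"""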
from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Iterable, Optional, Sequence
import chromadb as cdb
import dotenv
from langchain_openai import OpenAIEmbeddings
from utils.text_processing import DocumentProcessing
dotenv.load_dotenv()
def _default_storage_path() -> str:
"""Return on-disk location for the Chroma persistent client."""
env_path = os.getenv("VECTOR_DB_PATH")
if env_path:
return env_path
project_root = Path(__file__).resolve().parent.parent
storage_dir = project_root / "data" / "chroma"
storage_dir.mkdir(parents=True, exist_ok=True)
return str(storage_dir)
class VectorDB:
"""Light wrapper around a persistent Chroma collection."""
def __init__(
self,
*,
collection_name: str = "me_profile",
persist_directory: Optional[str] = None,
embedding_model: Optional[OpenAIEmbeddings] = None,
) -> None:
self.persist_directory = persist_directory or _default_storage_path()
self.client = cdb.PersistentClient(path=self.persist_directory)
        try:
            self.collection = self.client.get_or_create_collection(collection_name)
        except Exception:
            # Fallback for older Chroma clients: reuse the collection if it
            # already exists, otherwise create it.
            try:
                self.collection = self.client.get_collection(collection_name)
            except Exception:
                self.collection = self.client.create_collection(collection_name)
self.embedding_model = embedding_model or OpenAIEmbeddings(
model="text-embedding-3-large",
api_key=os.getenv("OPENAI_API_KEY"),
)
        # Auto-initialize from the project's 'me/' directory when the
        # collection is empty and that directory exists.
        try:
            me_dir = Path(__file__).resolve().parent.parent / "me"
            if self.collection.count() == 0 and me_dir.is_dir():
                # Skip the (potentially expensive) embedding build when no
                # OpenAI API key is configured.
                if os.getenv("OPENAI_API_KEY"):
                    dp = DocumentProcessing()
                    dp.create_vector_db_from_directory(str(me_dir))
        except Exception:
            # If the auto-build fails, continue with an empty DB; the app still runs.
            pass
# ------------------------------------------------------------------
# Document ingestion helpers
# ------------------------------------------------------------------
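    # When callers do not supply precomputed embeddings, the helpers below
    # compute them with the configured OpenAI embedding model before writing
    # to Chroma.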
def add_documents(
self,
documents: Sequence[str],
*,
metadatas: Optional[Sequence[dict[str, Any]]] = None,
ids: Optional[Sequence[str]] = None,
embeddings: Optional[Sequence[Sequence[float]]] = None,
) -> None:
"""Add documents to the Chroma collection."""
documents = list(documents)
if not documents:
return
count = len(documents)
if metadatas is None:
metadatas = [{} for _ in range(count)]
if ids is None:
ids = [f"doc_{i}" for i in range(count)]
if embeddings is None:
embeddings = self.embedding_model.embed_documents(list(documents))
self.collection.add(
documents=documents,
metadatas=list(metadatas),
ids=list(ids),
embeddings=list(embeddings),
)
# ------------------------------------------------------------------
# Query helpers
# ------------------------------------------------------------------
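    # Queries are embedded client-side with the same OpenAI model used at
    # ingestion time, rather than relying on the collection's default
    # embedding function, so query and document vectors stay comparable.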
    def query(
        self,
        query_texts: str | Iterable[str],
        *,
        k: int = 5,
        include: Optional[Sequence[str]] = None,
    ) -> dict[str, Any]:
        """Query the collection with one or more natural-language strings."""
        if isinstance(query_texts, str):
            query_texts = [query_texts]
        else:
            query_texts = list(query_texts)
        if not query_texts:
            raise ValueError("query_texts must contain at least one string")
        query_embeddings = self.embedding_model.embed_documents(list(query_texts))
        # Chroma accepts either query_texts or query_embeddings, not both, so
        # only the locally computed embeddings are passed.
        query_kwargs: dict[str, Any] = {
            "query_embeddings": list(query_embeddings),
            "n_results": k,
        }
        if include is not None:
            query_kwargs["include"] = list(include)
        return self.collection.query(**query_kwargs)
# ------------------------------------------------------------------
# Thin wrappers around underlying collection methods
# ------------------------------------------------------------------
    def upsert(
        self,
        documents: Sequence[str],
        *,
        metadatas: Optional[Sequence[dict[str, Any]]] = None,
        ids: Optional[Sequence[str]] = None,
        embeddings: Optional[Sequence[Sequence[float]]] = None,
    ) -> None:
        documents = list(documents)
        if ids is None:
            # Chroma requires ids for upsert; fall back to the same default
            # scheme used by add_documents.
            ids = [f"doc_{i}" for i in range(len(documents))]
        if embeddings is None:
            embeddings = self.embedding_model.embed_documents(documents)
        self.collection.upsert(
            documents=documents,
            metadatas=list(metadatas) if metadatas is not None else None,
            ids=list(ids),
            embeddings=list(embeddings),
        )
def delete(self, ids: Sequence[str]) -> None:
self.collection.delete(ids=list(ids))
def update(
self,
ids: Sequence[str],
documents: Optional[Sequence[str]] = None,
metadatas: Optional[Sequence[dict[str, Any]]] = None,
embeddings: Optional[Sequence[Sequence[float]]] = None,
) -> None:
if documents is not None and embeddings is None:
embeddings = self.embedding_model.embed_documents(list(documents))
self.collection.update(
ids=list(ids),
documents=list(documents) if documents is not None else None,
metadatas=list(metadatas) if metadatas is not None else None,
embeddings=list(embeddings) if embeddings is not None else None,
)
def get(self, ids: Sequence[str]) -> dict[str, Any]:
return self.collection.get(ids=list(ids))
def count(self) -> int:
return self.collection.count()
    def list(self) -> list[str]:
        # Chroma collections expose no ``list`` method; return the stored ids.
        return self.collection.get().get("ids", [])
    def delete_all(self) -> None:
        # ``Collection.delete`` needs ids or a filter, so fetch every id first.
        all_ids = self.collection.get().get("ids", [])
        if all_ids:
            self.collection.delete(ids=all_ids)
def get_all(self) -> dict[str, Any]:
return self.collection.get()
    def get_all_metadata(self) -> list[dict[str, Any]]:
        return self.collection.get(include=["metadatas"]).get("metadatas", [])
    def get_all_ids(self) -> list[str]:
        # "ids" is not a valid ``include`` value; ``get`` always returns ids.
        return self.collection.get().get("ids", [])
def get_all_texts(self) -> list[str]:
return self.collection.get(include=["documents"]).get("documents", []) # type: ignore[assignment]
def get_all_embeddings(self) -> list[list[float]]:
return self.collection.get(include=["embeddings"]).get("embeddings", []) # type: ignore[assignment]
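# Minimal usage sketch (illustrative only): assumes OPENAI_API_KEY is set and
# uses a throwaway collection name so the real "me_profile" data is untouched.
if __name__ == "__main__":
    db = VectorDB(collection_name="demo_profile")
    db.add_documents(
        ["Example note about my background.", "Example note about my projects."],
        ids=["demo_0", "demo_1"],
    )
    hits = db.query("What is in my background?", k=2)
    print(hits.get("documents"))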