Spaces:
Running
Running
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Iterable, Optional, Sequence | |
| import chromadb as cdb | |
| import dotenv | |
| from langchain_openai import OpenAIEmbeddings | |
| from utils.text_processing import DocumentProcessing | |
| dotenv.load_dotenv() | |
def _default_storage_path() -> str:
    """Resolve the on-disk directory for the Chroma persistent client.

    A non-empty ``VECTOR_DB_PATH`` environment variable takes precedence and
    is returned exactly as given (the directory is not created here).
    Otherwise the location is ``data/chroma`` under the directory two levels
    above this file, and that directory is created if it does not exist yet.

    Returns:
        The storage directory as a string path.
    """
    override = os.getenv("VECTOR_DB_PATH")
    if not override:
        # No usable override: fall back to the in-repo location.
        fallback_dir = Path(__file__).resolve().parent.parent / "data" / "chroma"
        fallback_dir.mkdir(parents=True, exist_ok=True)
        return str(fallback_dir)
    return override
class VectorDB:
    """Light wrapper around a persistent Chroma collection.

    Embeddings are always computed with ``embedding_model`` and handed to
    Chroma explicitly, both when writing records and when querying.
    """

    def __init__(
        self,
        *,
        collection_name: str = "me_profile",
        persist_directory: Optional[str] = None,
        embedding_model: Optional["OpenAIEmbeddings"] = None,
    ) -> None:
        """Open (or create) the collection and optionally seed it.

        Args:
            collection_name: Name of the Chroma collection to open or create.
            persist_directory: Storage directory; when falsy the result of
                ``_default_storage_path()`` is used.
            embedding_model: Embedding backend; when falsy an
                ``OpenAIEmbeddings`` instance for ``text-embedding-3-large``
                is built with the ``OPENAI_API_KEY`` environment variable.
        """
        # Directory handed to the persistent Chroma client.
        self.persist_directory = persist_directory or _default_storage_path()
        # Underlying Chroma client and the single collection this wrapper uses.
        self.client = cdb.PersistentClient(path=self.persist_directory)
        try:
            self.collection = self.client.get_or_create_collection(collection_name)
        except Exception:
            # Fallback for older Chroma versions
            self.collection = self.client.create_collection(collection_name)
        # Model used for every embed call made by this wrapper.
        self.embedding_model = embedding_model or OpenAIEmbeddings(
            model="text-embedding-3-large",
            api_key=os.getenv("OPENAI_API_KEY"),
        )
        self._auto_initialize()

    def _auto_initialize(self) -> None:
        """Best-effort: seed an empty collection from the ``me/`` directory.

        Runs only when the collection is empty, a ``me`` directory exists two
        levels above this file, and ``OPENAI_API_KEY`` is set. Any failure is
        logged and otherwise ignored so construction still succeeds with an
        empty collection.
        """
        import logging  # function-scope import keeps this block self-contained

        try:
            profile_dir = Path(__file__).resolve().parent.parent / "me"
            if self.collection.count() != 0 or not profile_dir.is_dir():
                return
            # Skip the heavy build when no OpenAI key is available.
            if not os.getenv("OPENAI_API_KEY"):
                return
            # NOTE(review): if DocumentProcessing itself constructs a VectorDB
            # while the collection is still empty this could re-enter here —
            # confirm against utils.text_processing.
            DocumentProcessing().create_vector_db_from_directory(str(profile_dir))
        except Exception:
            # Deliberately non-fatal, but no longer silent.
            logging.getLogger(__name__).warning(
                "Automatic vector DB initialisation failed; continuing with "
                "the collection as-is.",
                exc_info=True,
            )

    @staticmethod
    def _as_text_list(texts: Iterable[str]) -> list[str]:
        """Materialize *texts*; a bare string counts as one text, not characters."""
        if isinstance(texts, str):
            return [texts]
        return list(texts)

    # ------------------------------------------------------------------
    # Document ingestion helpers
    # ------------------------------------------------------------------
    def add_documents(
        self,
        documents: Sequence[str],
        *,
        metadatas: Optional[Sequence[dict[str, Any]]] = None,
        ids: Optional[Sequence[str]] = None,
        embeddings: Optional[Sequence[Sequence[float]]] = None,
    ) -> None:
        """Add documents to the Chroma collection.

        Args:
            documents: Texts to store. An empty input is a no-op.
            metadatas: Optional per-document metadata. Forwarded as ``None``
                when omitted, rather than as a list of empty dicts.
            ids: Optional per-document ids. When omitted, a fresh
                ``doc_<uuid4 hex>`` id is generated per document so repeated
                calls do not reuse the same ids.
            embeddings: Optional precomputed vectors; computed with
                ``embedding_model.embed_documents`` when omitted.
        """
        import uuid  # function-scope import keeps this block self-contained

        texts = self._as_text_list(documents)
        if not texts:
            return
        if ids is None:
            ids = [f"doc_{uuid.uuid4().hex}" for _ in texts]
        if embeddings is None:
            embeddings = self.embedding_model.embed_documents(texts)
        self.collection.add(
            documents=texts,
            metadatas=list(metadatas) if metadatas is not None else None,
            ids=list(ids),
            embeddings=list(embeddings),
        )

    # ------------------------------------------------------------------
    # Query helpers
    # ------------------------------------------------------------------
    def query(
        self,
        query_texts: Iterable[str],
        *,
        k: int = 5,
        include: Optional[Sequence[str]] = None,
    ) -> dict[str, Any]:
        """Query the collection with one or more natural-language strings.

        Args:
            query_texts: A single string or an iterable of strings.
            k: Number of results requested per query text.
            include: Optional list of result fields; when omitted the
                argument is not sent, so the collection's own default applies.

        Returns:
            Whatever ``collection.query`` returns.

        Raises:
            ValueError: If ``query_texts`` is empty.
        """
        texts = self._as_text_list(query_texts)
        if not texts:
            raise ValueError("query_texts must contain at least one string")
        # Only the embeddings are sent: the texts were already embedded with
        # our own model, and texts plus embeddings together are not a
        # portable combination across Chroma releases.
        request = {
            "query_embeddings": list(self.embedding_model.embed_documents(texts)),
            "n_results": k,
        }
        if include is not None:
            request["include"] = list(include)
        return self.collection.query(**request)

    # ------------------------------------------------------------------
    # Thin wrappers around underlying collection methods
    # ------------------------------------------------------------------
    def upsert(
        self,
        documents: Sequence[str],
        *,
        metadatas: Optional[Sequence[dict[str, Any]]] = None,
        ids: Optional[Sequence[str]] = None,
        embeddings: Optional[Sequence[Sequence[float]]] = None,
    ) -> None:
        """Insert or overwrite documents, embedding them when needed.

        An empty ``documents`` input is a no-op, matching ``add_documents``.
        ``ids`` is forwarded unchanged (``None`` included).
        NOTE(review): Chroma's upsert appears to require ids — confirm
        whether ``ids`` should be mandatory here.
        """
        texts = self._as_text_list(documents)
        if not texts:
            return
        if embeddings is None:
            embeddings = self.embedding_model.embed_documents(texts)
        self.collection.upsert(
            documents=texts,
            metadatas=list(metadatas) if metadatas is not None else None,
            ids=list(ids) if ids is not None else None,
            embeddings=list(embeddings),
        )

    def delete(self, ids: Sequence[str]) -> None:
        """Delete the records with the given ids."""
        self.collection.delete(ids=list(ids))

    def update(
        self,
        ids: Sequence[str],
        documents: Optional[Sequence[str]] = None,
        metadatas: Optional[Sequence[dict[str, Any]]] = None,
        embeddings: Optional[Sequence[Sequence[float]]] = None,
    ) -> None:
        """Update existing records.

        When new ``documents`` are given without ``embeddings``, the vectors
        are recomputed so they stay in sync with the text.
        """
        texts = self._as_text_list(documents) if documents is not None else None
        if texts is not None and embeddings is None:
            embeddings = self.embedding_model.embed_documents(texts)
        self.collection.update(
            ids=list(ids),
            documents=texts,
            metadatas=list(metadatas) if metadatas is not None else None,
            embeddings=list(embeddings) if embeddings is not None else None,
        )

    def get(self, ids: Sequence[str]) -> dict[str, Any]:
        """Return the collection's ``get`` result for the given ids."""
        return self.collection.get(ids=list(ids))

    def count(self) -> int:
        """Return the number of records reported by the collection."""
        return self.collection.count()

    def delete_all(self) -> None:
        """Remove every record from the collection.

        The ids are fetched first and deleted explicitly, because a
        selector-less ``delete()`` is not accepted by every Chroma release.
        An empty collection is left untouched.
        """
        stored_ids = self.get_all_ids()
        if stored_ids:
            self.collection.delete(ids=stored_ids)

    def get_all(self) -> dict[str, Any]:
        """Return the collection's ``get`` result with no filter."""
        return self.collection.get()

    def get_all_metadata(self) -> list[dict[str, Any]]:
        """Return the ``get`` result restricted to metadatas.

        NOTE(review): this returns the whole result mapping, not the list the
        annotation promises; left as-is because callers may index into it.
        """
        return self.collection.get(include=["metadatas"])  # type: ignore[return-value]

    def get_all_ids(self) -> list[str]:
        """Return the ids of all stored records."""
        # Ids are part of every get() result; "ids" itself is not an include
        # value, so an empty include list asks for the ids alone.
        found = self.collection.get(include=[]).get("ids")
        return [] if found is None else list(found)

    def get_all_texts(self) -> list[str]:
        """Return the document texts of all stored records."""
        found = self.collection.get(include=["documents"]).get("documents")
        return [] if found is None else found

    def get_all_embeddings(self) -> list[list[float]]:
        """Return the embeddings of all stored records."""
        found = self.collection.get(include=["embeddings"]).get("embeddings")
        # Explicit None test: the value may be an array type whose truthiness
        # is ambiguous.
        return [] if found is None else found  # type: ignore[return-value]

    # Kept last on purpose: this name shadows the builtin ``list`` inside the
    # class namespace, which would break any later ``list[...]`` annotation
    # that is evaluated eagerly.
    def list(self) -> list[str]:
        """Return the ids of all stored records.

        The previous implementation delegated to ``collection.list()``, a
        method the collection object does not provide.
        """
        return self.get_all_ids()