badalsahani's picture
feat: chroma initial deploy
287a0bc
import multiprocessing
import re
from typing import Any, Callable, Dict, Union
from chromadb.types import Metadata
# A validator receives the raw metadata value for one parameter and reports
# whether that value is acceptable.
Validator = Callable[[Union[str, int, float]], bool]


def _is_int(p: Union[str, int, float]) -> bool:
    """Return True when *p* is an integer that is not a bool."""
    # bool is a subclass of int, so a bare isinstance(p, int) would accept
    # True/False and they would later be coerced to 1/0 by int().
    return isinstance(p, int) and not isinstance(p, bool)


def _is_number(p: Union[str, int, float]) -> bool:
    """Return True when *p* is an int or float that is not a bool."""
    return isinstance(p, (int, float)) and not isinstance(p, bool)


# Maps each supported "hnsw:" metadata key to the check its value must pass.
param_validators: Dict[str, Validator] = {
    # fullmatch instead of match(r"^...$"): "$" also matches just before a
    # trailing newline, so a value such as "l2\n" used to be accepted.
    "hnsw:space": lambda p: re.fullmatch(r"(?:l2|cosine|ip)", str(p)) is not None,
    "hnsw:construction_ef": _is_int,
    "hnsw:search_ef": _is_int,
    "hnsw:M": _is_int,
    "hnsw:num_threads": _is_int,
    "hnsw:resize_factor": _is_number,
}
def _is_int_above_two(candidate: Union[str, int, float]) -> bool:
    """Return True only for integer values strictly greater than 2."""
    if not isinstance(candidate, int):
        return False
    return candidate > 2


# Additional "hnsw:" keys that are only accepted for persistent hnsw; both
# values must be integers greater than 2.
persistent_param_validators: Dict[str, Validator] = dict.fromkeys(
    ("hnsw:batch_size", "hnsw:sync_threshold"),
    _is_int_above_two,
)
class Params:
    """Helpers for picking out and validating ``hnsw:``-prefixed metadata."""

    @staticmethod
    def _select(metadata: Metadata) -> Dict[str, Any]:
        """Return the entries of *metadata* whose key starts with ``hnsw:``.

        Falsy metadata (``None`` or an empty mapping) yields an empty dict,
        matching how ``HnswParams.__init__`` treats missing metadata.
        """
        if not metadata:
            return {}
        return {
            param: value
            for param, value in metadata.items()
            if param.startswith("hnsw:")
        }

    @staticmethod
    def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None:
        """Check every entry of *metadata* against *validators*.

        Raises:
            ValueError: if a key has no entry in *validators*, or if the
                validator registered for a key rejects its value.
        """
        for param, value in metadata.items():
            if param not in validators:
                raise ValueError(f"Unknown HNSW parameter: {param}")
            if not validators[param](value):
                raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
class HnswParams(Params):
    """HNSW settings read from ``hnsw:``-prefixed metadata keys.

    Each attribute is taken from the matching metadata key and coerced to
    its declared type; a default is used when the key is absent.
    """

    space: str
    construction_ef: int
    search_ef: int
    M: int
    num_threads: int
    resize_factor: float

    def __init__(self, metadata: Metadata):
        """Populate the attributes from *metadata*; falsy metadata means all defaults."""
        source = metadata if metadata else {}
        lookup = source.get

        self.space = str(lookup("hnsw:space", "l2"))
        self.construction_ef = int(lookup("hnsw:construction_ef", 100))
        self.search_ef = int(lookup("hnsw:search_ef", 10))
        self.M = int(lookup("hnsw:M", 16))
        # The fallback thread count is the machine's CPU count.
        thread_setting = lookup("hnsw:num_threads", multiprocessing.cpu_count())
        self.num_threads = int(thread_setting)
        self.resize_factor = float(lookup("hnsw:resize_factor", 1.2))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Return the ``hnsw:`` entries of *metadata* after validating them.

        Raises:
            ValueError: if an ``hnsw:`` key is not in ``param_validators`` or
                its value fails the registered check.
        """
        hnsw_only = HnswParams._select(metadata)
        HnswParams._validate(hnsw_only, param_validators)
        return hnsw_only
class PersistentHnswParams(HnswParams):
    """HNSW settings plus the extra keys accepted for persistent hnsw.

    Adds ``batch_size`` (default 100) and ``sync_threshold`` (default 1000),
    read from ``hnsw:batch_size`` and ``hnsw:sync_threshold``.
    """

    batch_size: int
    sync_threshold: int

    def __init__(self, metadata: Metadata):
        """Populate base and persistent attributes from *metadata*.

        Falsy metadata (``None`` or empty) results in all defaults.
        """
        super().__init__(metadata)
        # HnswParams.__init__ normalizes falsy metadata only for itself; do the
        # same here so that None does not fail on .get() below.
        metadata = metadata or {}
        self.batch_size = int(metadata.get("hnsw:batch_size", 100))
        self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000))

    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Return the ``hnsw:`` entries of *metadata* after validating them.

        Validation uses the base validators combined with the persistent-only
        ones.

        Raises:
            ValueError: if an ``hnsw:`` key is unknown to the combined
                validators or its value fails the registered check.
        """
        all_validators = {**param_validators, **persistent_param_validators}
        segment_metadata = PersistentHnswParams._select(metadata)
        PersistentHnswParams._validate(segment_metadata, all_validators)
        return segment_metadata