Spaces:
Sleeping
Sleeping
import multiprocessing | |
import re | |
from typing import Any, Callable, Dict, Union | |
from chromadb.types import Metadata | |
# A validator receives the raw metadata value for one parameter and reports
# whether that value is acceptable.
Validator = Callable[[Union[str, int, float]], bool]

# Validators for the HNSW parameters, keyed by their metadata key.
param_validators: Dict[str, Validator] = {
    # fullmatch is used instead of match with "^...$" because "$" also matches
    # just before a trailing newline, which would let "l2\n" through.
    "hnsw:space": lambda p: bool(re.fullmatch(r"(?:l2|cosine|ip)", str(p))),
    "hnsw:construction_ef": lambda p: isinstance(p, int),
    "hnsw:search_ef": lambda p: isinstance(p, int),
    "hnsw:M": lambda p: isinstance(p, int),
    "hnsw:num_threads": lambda p: isinstance(p, int),
    "hnsw:resize_factor": lambda p: isinstance(p, (int, float)),
}

# Extra params used for persistent hnsw; both must be integers greater than 2.
persistent_param_validators: Dict[str, Validator] = {
    "hnsw:batch_size": lambda p: isinstance(p, int) and p > 2,
    "hnsw:sync_threshold": lambda p: isinstance(p, int) and p > 2,
}
class Params:
    """Helpers for selecting and validating "hnsw:"-prefixed metadata entries."""

    # Neither helper uses instance or class state. Declaring them static keeps
    # calls through the class working as before and also makes calls through
    # an instance work instead of binding the instance as ``metadata``.
    @staticmethod
    def _select(metadata: Metadata) -> Dict[str, Any]:
        """Return the entries of ``metadata`` whose key starts with "hnsw:".

        Args:
            metadata: Mapping of parameter names to values.

        Returns:
            A new dict holding only the "hnsw:"-prefixed entries.
        """
        return {
            param: value
            for param, value in metadata.items()
            if param.startswith("hnsw:")
        }

    @staticmethod
    def _validate(metadata: Dict[str, Any], validators: Dict[str, Validator]) -> None:
        """Check every entry of ``metadata`` against ``validators``.

        Args:
            metadata: Parameter names mapped to the values to check.
            validators: Parameter names mapped to a predicate that returns
                True when a value is acceptable.

        Raises:
            ValueError: If a parameter has no validator, or if its validator
                rejects the value.
        """
        for param, value in metadata.items():
            if param not in validators:
                raise ValueError(f"Unknown HNSW parameter: {param}")
            if not validators[param](value):
                raise ValueError(f"Invalid value for HNSW parameter: {param} = {value}")
class HnswParams(Params):
    """HNSW settings read from metadata, with a default for each missing key."""

    space: str
    construction_ef: int
    search_ef: int
    M: int
    num_threads: int
    resize_factor: float

    def __init__(self, metadata: Metadata):
        """Populate the settings from ``metadata``.

        A falsy ``metadata`` (for example None or an empty mapping) is treated
        as empty, so every attribute takes its default. Values are coerced
        with ``str``/``int``/``float`` but not validated here.
        """
        metadata = metadata or {}
        self.space = str(metadata.get("hnsw:space", "l2"))
        self.construction_ef = int(metadata.get("hnsw:construction_ef", 100))
        self.search_ef = int(metadata.get("hnsw:search_ef", 10))
        self.M = int(metadata.get("hnsw:M", 16))
        # Defaults to the number of CPUs reported by multiprocessing.
        self.num_threads = int(
            metadata.get("hnsw:num_threads", multiprocessing.cpu_count())
        )
        self.resize_factor = float(metadata.get("hnsw:resize_factor", 1.2))

    # Static so the method behaves the same whether it is called on the class
    # or on an instance; it takes no ``self``.
    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Validate and return only the relevant hnsw params.

        Raises:
            ValueError: If an "hnsw:"-prefixed key is not a known parameter or
                its value is rejected by the matching validator.
        """
        segment_metadata = HnswParams._select(metadata)
        HnswParams._validate(segment_metadata, param_validators)
        return segment_metadata
class PersistentHnswParams(HnswParams):
    """HNSW settings plus the extra parameters used for persistent hnsw."""

    batch_size: int
    sync_threshold: int

    def __init__(self, metadata: Metadata):
        """Populate the base and persistent settings from ``metadata``.

        A falsy ``metadata`` (for example None) is treated as empty, so
        ``batch_size`` defaults to 100 and ``sync_threshold`` to 1000.
        """
        super().__init__(metadata)
        # The parent constructor normalises a falsy ``metadata`` only in its
        # own local scope; do the same here so the lookups below do not fail
        # with AttributeError when ``metadata`` is None.
        metadata = metadata or {}
        self.batch_size = int(metadata.get("hnsw:batch_size", 100))
        self.sync_threshold = int(metadata.get("hnsw:sync_threshold", 1000))

    # Static so the method behaves the same whether it is called on the class
    # or on an instance; it takes no ``self``.
    @staticmethod
    def extract(metadata: Metadata) -> Metadata:
        """Validate and return only the relevant hnsw params.

        Both the base and the persistent-only parameters are accepted.

        Raises:
            ValueError: If an "hnsw:"-prefixed key is not a known parameter or
                its value is rejected by the matching validator.
        """
        all_validators = {**param_validators, **persistent_param_validators}
        segment_metadata = PersistentHnswParams._select(metadata)
        PersistentHnswParams._validate(segment_metadata, all_validators)
        return segment_metadata