from threading import Lock from chromadb.segment import ( SegmentImplementation, SegmentManager, MetadataReader, SegmentType, VectorReader, S, ) import logging from chromadb.segment.impl.manager.cache.cache import SegmentLRUCache, BasicCache,SegmentCache import os from chromadb.config import System, get_class from chromadb.db.system import SysDB from overrides import override from chromadb.segment.impl.vector.local_persistent_hnsw import ( PersistentLocalHnswSegment, ) from chromadb.telemetry.opentelemetry import ( OpenTelemetryClient, OpenTelemetryGranularity, trace_method, ) from chromadb.types import Collection, Operation, Segment, SegmentScope, Metadata from typing import Dict, Type, Sequence, Optional, cast from uuid import UUID, uuid4 import platform from chromadb.utils.lru_cache import LRUCache from chromadb.utils.directory import get_directory_size if platform.system() != "Windows": import resource elif platform.system() == "Windows": import ctypes SEGMENT_TYPE_IMPLS = { SegmentType.SQLITE: "chromadb.segment.impl.metadata.sqlite.SqliteMetadataSegment", SegmentType.HNSW_LOCAL_MEMORY: "chromadb.segment.impl.vector.local_hnsw.LocalHnswSegment", SegmentType.HNSW_LOCAL_PERSISTED: "chromadb.segment.impl.vector.local_persistent_hnsw.PersistentLocalHnswSegment", } class LocalSegmentManager(SegmentManager): _sysdb: SysDB _system: System _opentelemetry_client: OpenTelemetryClient _instances: Dict[UUID, SegmentImplementation] _vector_instances_file_handle_cache: LRUCache[ UUID, PersistentLocalHnswSegment ] # LRU cache to manage file handles across vector segment instances _vector_segment_type: SegmentType = SegmentType.HNSW_LOCAL_MEMORY _lock: Lock _max_file_handles: int def __init__(self, system: System): super().__init__(system) self._sysdb = self.require(SysDB) self._system = system self._opentelemetry_client = system.require(OpenTelemetryClient) self.logger = logging.getLogger(__name__) self._instances = {} self.segment_cache: Dict[SegmentScope, SegmentCache] = {SegmentScope.METADATA: BasicCache()} if system.settings.chroma_segment_cache_policy == "LRU" and system.settings.chroma_memory_limit_bytes > 0: self.segment_cache[SegmentScope.VECTOR] = SegmentLRUCache(capacity=system.settings.chroma_memory_limit_bytes,callback=lambda k, v: self.callback_cache_evict(v), size_func=lambda k: self._get_segment_disk_size(k)) else: self.segment_cache[SegmentScope.VECTOR] = BasicCache() self._lock = Lock() # TODO: prototyping with distributed segment for now, but this should be a configurable option # we need to think about how to handle this configuration if self._system.settings.require("is_persistent"): self._vector_segment_type = SegmentType.HNSW_LOCAL_PERSISTED if platform.system() != "Windows": self._max_file_handles = resource.getrlimit(resource.RLIMIT_NOFILE)[0] else: self._max_file_handles = ctypes.windll.msvcrt._getmaxstdio() # type: ignore segment_limit = ( self._max_file_handles // PersistentLocalHnswSegment.get_file_handle_count() ) self._vector_instances_file_handle_cache = LRUCache( segment_limit, callback=lambda _, v: v.close_persistent_index() ) def callback_cache_evict(self, segment: Segment): collection_id = segment["collection"] self.logger.info(f"LRU cache evict collection {collection_id}") instance = self._instance(segment) instance.stop() del self._instances[segment["id"]] @override def start(self) -> None: for instance in self._instances.values(): instance.start() super().start() @override def stop(self) -> None: for instance in self._instances.values(): instance.stop() super().stop() @override def reset_state(self) -> None: for instance in self._instances.values(): instance.stop() instance.reset_state() self._instances = {} self.segment_cache[SegmentScope.VECTOR].reset() super().reset_state() @trace_method( "LocalSegmentManager.create_segments", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) @override def create_segments(self, collection: Collection) -> Sequence[Segment]: vector_segment = _segment( self._vector_segment_type, SegmentScope.VECTOR, collection ) metadata_segment = _segment( SegmentType.SQLITE, SegmentScope.METADATA, collection ) return [vector_segment, metadata_segment] @trace_method( "LocalSegmentManager.delete_segments", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) @override def delete_segments(self, collection_id: UUID) -> Sequence[UUID]: segments = self._sysdb.get_segments(collection=collection_id) for segment in segments: if segment["id"] in self._instances: if segment["type"] == SegmentType.HNSW_LOCAL_PERSISTED.value: instance = self.get_segment(collection_id, VectorReader) instance.delete() elif segment["type"] == SegmentType.SQLITE.value: instance = self.get_segment(collection_id, MetadataReader) instance.delete() del self._instances[segment["id"]] if segment["scope"] is SegmentScope.VECTOR: self.segment_cache[SegmentScope.VECTOR].pop(collection_id) if segment["scope"] is SegmentScope.METADATA: self.segment_cache[SegmentScope.METADATA].pop(collection_id) return [s["id"] for s in segments] @trace_method( "LocalSegmentManager.get_segment", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) def _get_segment_disk_size(self, collection_id: UUID) -> int: segments = self._sysdb.get_segments(collection=collection_id, scope=SegmentScope.VECTOR) if len(segments) == 0: return 0 # With local segment manager (single server chroma), a collection always have one segment. size = get_directory_size( os.path.join(self._system.settings.require("persist_directory"), str(segments[0]["id"]))) return size def _get_segment_sysdb(self, collection_id:UUID, scope: SegmentScope): segments = self._sysdb.get_segments(collection=collection_id, scope=scope) known_types = set([k.value for k in SEGMENT_TYPE_IMPLS.keys()]) # Get the first segment of a known type segment = next(filter(lambda s: s["type"] in known_types, segments)) return segment @override def get_segment(self, collection_id: UUID, type: Type[S]) -> S: if type == MetadataReader: scope = SegmentScope.METADATA elif type == VectorReader: scope = SegmentScope.VECTOR else: raise ValueError(f"Invalid segment type: {type}") segment = self.segment_cache[scope].get(collection_id) if segment is None: segment = self._get_segment_sysdb(collection_id, scope) self.segment_cache[scope].set(collection_id, segment) # Instances must be atomically created, so we use a lock to ensure that only one thread # creates the instance. with self._lock: instance = self._instance(segment) return cast(S, instance) @trace_method( "LocalSegmentManager.hint_use_collection", OpenTelemetryGranularity.OPERATION_AND_SEGMENT, ) @override def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None: # The local segment manager responds to hints by pre-loading both the metadata and vector # segments for the given collection. for type in [MetadataReader, VectorReader]: # Just use get_segment to load the segment into the cache instance = self.get_segment(collection_id, type) # If the segment is a vector segment, we need to keep segments in an LRU cache # to avoid hitting the OS file handle limit. if type == VectorReader and self._system.settings.require("is_persistent"): instance = cast(PersistentLocalHnswSegment, instance) instance.open_persistent_index() self._vector_instances_file_handle_cache.set(collection_id, instance) def _cls(self, segment: Segment) -> Type[SegmentImplementation]: classname = SEGMENT_TYPE_IMPLS[SegmentType(segment["type"])] cls = get_class(classname, SegmentImplementation) return cls def _instance(self, segment: Segment) -> SegmentImplementation: if segment["id"] not in self._instances: cls = self._cls(segment) instance = cls(self._system, segment) instance.start() self._instances[segment["id"]] = instance return self._instances[segment["id"]] def _segment(type: SegmentType, scope: SegmentScope, collection: Collection) -> Segment: """Create a metadata dict, propagating metadata correctly for the given segment type.""" cls = get_class(SEGMENT_TYPE_IMPLS[type], SegmentImplementation) collection_metadata = collection.get("metadata", None) metadata: Optional[Metadata] = None if collection_metadata: metadata = cls.propagate_collection_metadata(collection_metadata) return Segment( id=uuid4(), type=type.value, scope=scope, topic=collection["topic"], collection=collection["id"], metadata=metadata )