badalsahani's picture
feat: chroma initial deploy
287a0bc
from overrides import EnforceOverrides, override
from typing import List, Optional, Sequence
from chromadb.config import System
from chromadb.proto.convert import (
from_proto_vector_embedding_record,
from_proto_vector_query_result,
to_proto_vector,
)
from chromadb.segment import VectorReader
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.types import (
Metadata,
ScalarEncoding,
Segment,
VectorEmbeddingRecord,
VectorQuery,
VectorQueryResult,
)
from chromadb.proto.chroma_pb2_grpc import VectorReaderStub
from chromadb.proto.chroma_pb2 import (
GetVectorsRequest,
GetVectorsResponse,
QueryVectorsRequest,
QueryVectorsResponse,
)
import grpc
class GrpcVectorSegment(VectorReader, EnforceOverrides):
    """Vector segment reader that forwards reads to a remote gRPC service.

    The service address is taken from the ``grpc_url`` entry of the segment's
    metadata. ``get_vectors`` and ``query_vectors`` are served by RPCs on a
    ``VectorReaderStub``; ``count`` and ``delete`` are not implemented.
    """

    _vector_reader_stub: VectorReaderStub
    _segment: Segment
    _opentelemetry_client: OpenTelemetryClient

    def __init__(self, system: System, segment: Segment):
        """Open an insecure gRPC channel to the segment's reader service.

        Args:
            system: Used to obtain the ``OpenTelemetryClient``.
            segment: Segment description; its metadata must carry a
                ``grpc_url`` entry.

        Raises:
            Exception: If the segment has no metadata, or the metadata has no
                ``grpc_url`` entry (absent or ``None``).
        """
        # TODO: move to start() method
        # TODO: close channel in stop() method
        metadata = segment["metadata"]
        # Use .get() so an absent key produces the descriptive error below
        # instead of leaking a bare KeyError from the subscript.
        grpc_url = metadata.get("grpc_url") if metadata is not None else None
        if grpc_url is None:
            raise Exception("Missing grpc_url in segment metadata")

        channel = grpc.insecure_channel(grpc_url)
        self._vector_reader_stub = VectorReaderStub(channel)  # type: ignore
        self._segment = segment
        self._opentelemetry_client = system.require(OpenTelemetryClient)

    @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL)
    @override
    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        """Fetch embedding records from the remote segment via ``GetVectors``.

        Args:
            ids: Passed through unchanged as the request's ``ids`` field.

        Returns:
            One converted record per entry in the response's ``records``.
        """
        request = GetVectorsRequest(ids=ids, segment_id=self._segment["id"].hex)
        response: GetVectorsResponse = self._vector_reader_stub.GetVectors(request)
        return [
            from_proto_vector_embedding_record(record) for record in response.records
        ]

    @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL)
    @override
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        """Run a vector query against the remote segment via ``QueryVectors``.

        Args:
            query: Supplies ``vectors``, ``k``, ``allowed_ids`` and
                ``include_embeddings`` for the request. The vectors are sent
                with ``ScalarEncoding.FLOAT32``.

        Returns:
            One inner list per entry in the response's ``results``, each
            holding that entry's converted query results.
        """
        request = QueryVectorsRequest(
            vectors=[
                to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32)
                for v in query["vectors"]
            ],
            k=query["k"],
            allowed_ids=query["allowed_ids"],
            include_embeddings=query["include_embeddings"],
            segment_id=self._segment["id"].hex,
        )
        response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(request)
        return [
            [from_proto_vector_query_result(hit) for hit in per_query.results]
            for per_query in response.results
        ]

    @override
    def count(self) -> int:
        """Not implemented for the gRPC-backed segment."""
        raise NotImplementedError()

    @override
    def max_seqid(self) -> int:
        """Return 0 unconditionally; no sequence id is tracked here."""
        return 0

    @staticmethod
    @override
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        """Return the result of ``PersistentHnswParams.extract(metadata)``."""
        # Great example of why language sharing is nice.
        segment_metadata = PersistentHnswParams.extract(metadata)
        return segment_metadata

    @override
    def delete(self) -> None:
        """Not implemented for the gRPC-backed segment."""
        raise NotImplementedError()