File size: 3,649 Bytes
287a0bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from overrides import EnforceOverrides, override
from typing import List, Optional, Sequence
from chromadb.config import System
from chromadb.proto.convert import (
    from_proto_vector_embedding_record,
    from_proto_vector_query_result,
    to_proto_vector,
)
from chromadb.segment import VectorReader
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    Metadata,
    ScalarEncoding,
    Segment,
    VectorEmbeddingRecord,
    VectorQuery,
    VectorQueryResult,
)
from chromadb.proto.chroma_pb2_grpc import VectorReaderStub
from chromadb.proto.chroma_pb2 import (
    GetVectorsRequest,
    GetVectorsResponse,
    QueryVectorsRequest,
    QueryVectorsResponse,
)
import grpc


class GrpcVectorSegment(VectorReader, EnforceOverrides):
    _vector_reader_stub: VectorReaderStub
    _segment: Segment
    _opentelemetry_client: OpenTelemetryClient

    def __init__(self, system: System, segment: Segment):
        # TODO: move to start() method
        # TODO: close channel in stop() method
        if segment["metadata"] is None or segment["metadata"]["grpc_url"] is None:
            raise Exception("Missing grpc_url in segment metadata")

        channel = grpc.insecure_channel(segment["metadata"]["grpc_url"])
        self._vector_reader_stub = VectorReaderStub(channel)  # type: ignore
        self._segment = segment
        self._opentelemetry_client = system.require(OpenTelemetryClient)

    @trace_method("GrpcVectorSegment.get_vectors", OpenTelemetryGranularity.ALL)
    @override
    def get_vectors(
        self, ids: Optional[Sequence[str]] = None
    ) -> Sequence[VectorEmbeddingRecord]:
        request = GetVectorsRequest(ids=ids, segment_id=self._segment["id"].hex)
        response: GetVectorsResponse = self._vector_reader_stub.GetVectors(request)
        results: List[VectorEmbeddingRecord] = []
        for vector in response.records:
            result = from_proto_vector_embedding_record(vector)
            results.append(result)
        return results

    @trace_method("GrpcVectorSegment.query_vectors", OpenTelemetryGranularity.ALL)
    @override
    def query_vectors(
        self, query: VectorQuery
    ) -> Sequence[Sequence[VectorQueryResult]]:
        request = QueryVectorsRequest(
            vectors=[
                to_proto_vector(vector=v, encoding=ScalarEncoding.FLOAT32)
                for v in query["vectors"]
            ],
            k=query["k"],
            allowed_ids=query["allowed_ids"],
            include_embeddings=query["include_embeddings"],
            segment_id=self._segment["id"].hex,
        )
        response: QueryVectorsResponse = self._vector_reader_stub.QueryVectors(request)
        results: List[List[VectorQueryResult]] = []
        for result in response.results:
            curr_result: List[VectorQueryResult] = []
            for r in result.results:
                curr_result.append(from_proto_vector_query_result(r))
            results.append(curr_result)
        return results

    @override
    def count(self) -> int:
        raise NotImplementedError()

    @override
    def max_seqid(self) -> int:
        return 0

    @staticmethod
    @override
    def propagate_collection_metadata(metadata: Metadata) -> Optional[Metadata]:
        # Great example of why language sharing is nice.
        segment_metadata = PersistentHnswParams.extract(metadata)
        return segment_metadata

    @override
    def delete(self) -> None:
        raise NotImplementedError()