Spaces:
Sleeping
Sleeping
import array | |
from uuid import UUID | |
from typing import Dict, Optional, Tuple, Union, cast | |
from chromadb.api.types import Embedding | |
import chromadb.proto.chroma_pb2 as proto | |
from chromadb.utils.messageid import bytes_to_int, int_to_bytes | |
from chromadb.types import ( | |
Collection, | |
EmbeddingRecord, | |
Metadata, | |
Operation, | |
ScalarEncoding, | |
Segment, | |
SegmentScope, | |
SeqId, | |
SubmitEmbeddingRecord, | |
UpdateMetadata, | |
Vector, | |
VectorEmbeddingRecord, | |
VectorQueryResult, | |
) | |
# TODO: Unit tests for this file, handling optional states etc | |
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector: | |
if encoding == ScalarEncoding.FLOAT32: | |
as_bytes = array.array("f", vector).tobytes() | |
proto_encoding = proto.ScalarEncoding.FLOAT32 | |
elif encoding == ScalarEncoding.INT32: | |
as_bytes = array.array("i", vector).tobytes() | |
proto_encoding = proto.ScalarEncoding.INT32 | |
else: | |
raise ValueError( | |
f"Unknown encoding {encoding}, expected one of {ScalarEncoding.FLOAT32} \ | |
or {ScalarEncoding.INT32}" | |
) | |
return proto.Vector(dimension=len(vector), vector=as_bytes, encoding=proto_encoding) | |
def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]: | |
encoding = vector.encoding | |
as_array: Union[array.array[float], array.array[int]] | |
if encoding == proto.ScalarEncoding.FLOAT32: | |
as_array = array.array("f") | |
out_encoding = ScalarEncoding.FLOAT32 | |
elif encoding == proto.ScalarEncoding.INT32: | |
as_array = array.array("i") | |
out_encoding = ScalarEncoding.INT32 | |
else: | |
raise ValueError( | |
f"Unknown encoding {encoding}, expected one of \ | |
{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}" | |
) | |
as_array.frombytes(vector.vector) | |
return (as_array.tolist(), out_encoding) | |
def from_proto_operation(operation: proto.Operation) -> Operation: | |
if operation == proto.Operation.ADD: | |
return Operation.ADD | |
elif operation == proto.Operation.UPDATE: | |
return Operation.UPDATE | |
elif operation == proto.Operation.UPSERT: | |
return Operation.UPSERT | |
elif operation == proto.Operation.DELETE: | |
return Operation.DELETE | |
else: | |
# TODO: full error | |
raise RuntimeError(f"Unknown operation {operation}") | |
def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]: | |
return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False)) | |
def from_proto_update_metadata( | |
metadata: proto.UpdateMetadata, | |
) -> Optional[UpdateMetadata]: | |
return cast( | |
Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True) | |
) | |
def _from_proto_metadata_handle_none( | |
metadata: proto.UpdateMetadata, is_update: bool | |
) -> Optional[Union[UpdateMetadata, Metadata]]: | |
if not metadata.metadata: | |
return None | |
out_metadata: Dict[str, Union[str, int, float, None]] = {} | |
for key, value in metadata.metadata.items(): | |
if value.HasField("string_value"): | |
out_metadata[key] = value.string_value | |
elif value.HasField("int_value"): | |
out_metadata[key] = value.int_value | |
elif value.HasField("float_value"): | |
out_metadata[key] = value.float_value | |
elif is_update: | |
out_metadata[key] = None | |
else: | |
raise ValueError(f"Metadata key {key} value cannot be None") | |
return out_metadata | |
def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata: | |
return proto.UpdateMetadata( | |
metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()} | |
) | |
def from_proto_submit( | |
submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId | |
) -> EmbeddingRecord: | |
embedding, encoding = from_proto_vector(submit_embedding_record.vector) | |
record = EmbeddingRecord( | |
id=submit_embedding_record.id, | |
seq_id=seq_id, | |
embedding=embedding, | |
encoding=encoding, | |
metadata=from_proto_update_metadata(submit_embedding_record.metadata), | |
operation=from_proto_operation(submit_embedding_record.operation), | |
collection_id=UUID(hex=submit_embedding_record.collection_id), | |
) | |
return record | |
def from_proto_segment(segment: proto.Segment) -> Segment: | |
return Segment( | |
id=UUID(hex=segment.id), | |
type=segment.type, | |
scope=from_proto_segment_scope(segment.scope), | |
topic=segment.topic if segment.HasField("topic") else None, | |
collection=None | |
if not segment.HasField("collection") | |
else UUID(hex=segment.collection), | |
metadata=from_proto_metadata(segment.metadata) | |
if segment.HasField("metadata") | |
else None, | |
) | |
def to_proto_segment(segment: Segment) -> proto.Segment: | |
return proto.Segment( | |
id=segment["id"].hex, | |
type=segment["type"], | |
scope=to_proto_segment_scope(segment["scope"]), | |
topic=segment["topic"], | |
collection=None if segment["collection"] is None else segment["collection"].hex, | |
metadata=None | |
if segment["metadata"] is None | |
else to_proto_update_metadata(segment["metadata"]), | |
) | |
def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope: | |
if segment_scope == proto.SegmentScope.VECTOR: | |
return SegmentScope.VECTOR | |
elif segment_scope == proto.SegmentScope.METADATA: | |
return SegmentScope.METADATA | |
else: | |
raise RuntimeError(f"Unknown segment scope {segment_scope}") | |
def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope: | |
if segment_scope == SegmentScope.VECTOR: | |
return proto.SegmentScope.VECTOR | |
elif segment_scope == SegmentScope.METADATA: | |
return proto.SegmentScope.METADATA | |
else: | |
raise RuntimeError(f"Unknown segment scope {segment_scope}") | |
def to_proto_metadata_update_value( | |
value: Union[str, int, float, None] | |
) -> proto.UpdateMetadataValue: | |
if isinstance(value, str): | |
return proto.UpdateMetadataValue(string_value=value) | |
elif isinstance(value, int): | |
return proto.UpdateMetadataValue(int_value=value) | |
elif isinstance(value, float): | |
return proto.UpdateMetadataValue(float_value=value) | |
elif value is None: | |
return proto.UpdateMetadataValue() | |
else: | |
raise ValueError( | |
f"Unknown metadata value type {type(value)}, expected one of str, int, \ | |
float, or None" | |
) | |
def from_proto_collection(collection: proto.Collection) -> Collection: | |
return Collection( | |
id=UUID(hex=collection.id), | |
name=collection.name, | |
topic=collection.topic, | |
metadata=from_proto_metadata(collection.metadata) | |
if collection.HasField("metadata") | |
else None, | |
dimension=collection.dimension | |
if collection.HasField("dimension") and collection.dimension | |
else None, | |
database=collection.database, | |
tenant=collection.tenant, | |
) | |
def to_proto_collection(collection: Collection) -> proto.Collection: | |
return proto.Collection( | |
id=collection["id"].hex, | |
name=collection["name"], | |
topic=collection["topic"], | |
metadata=None | |
if collection["metadata"] is None | |
else to_proto_update_metadata(collection["metadata"]), | |
dimension=collection["dimension"], | |
tenant=collection["tenant"], | |
database=collection["database"], | |
) | |
def to_proto_operation(operation: Operation) -> proto.Operation: | |
if operation == Operation.ADD: | |
return proto.Operation.ADD | |
elif operation == Operation.UPDATE: | |
return proto.Operation.UPDATE | |
elif operation == Operation.UPSERT: | |
return proto.Operation.UPSERT | |
elif operation == Operation.DELETE: | |
return proto.Operation.DELETE | |
else: | |
raise ValueError( | |
f"Unknown operation {operation}, expected one of {Operation.ADD}, \ | |
{Operation.UPDATE}, {Operation.UPDATE}, or {Operation.DELETE}" | |
) | |
def to_proto_submit( | |
submit_record: SubmitEmbeddingRecord, | |
) -> proto.SubmitEmbeddingRecord: | |
vector = None | |
if submit_record["embedding"] is not None and submit_record["encoding"] is not None: | |
vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"]) | |
metadata = None | |
if submit_record["metadata"] is not None: | |
metadata = to_proto_update_metadata(submit_record["metadata"]) | |
return proto.SubmitEmbeddingRecord( | |
id=submit_record["id"], | |
vector=vector, | |
metadata=metadata, | |
operation=to_proto_operation(submit_record["operation"]), | |
collection_id=submit_record["collection_id"].hex, | |
) | |
def from_proto_vector_embedding_record( | |
embedding_record: proto.VectorEmbeddingRecord, | |
) -> VectorEmbeddingRecord: | |
return VectorEmbeddingRecord( | |
id=embedding_record.id, | |
seq_id=from_proto_seq_id(embedding_record.seq_id), | |
embedding=from_proto_vector(embedding_record.vector)[0], | |
) | |
def to_proto_vector_embedding_record( | |
embedding_record: VectorEmbeddingRecord, | |
encoding: ScalarEncoding, | |
) -> proto.VectorEmbeddingRecord: | |
return proto.VectorEmbeddingRecord( | |
id=embedding_record["id"], | |
seq_id=to_proto_seq_id(embedding_record["seq_id"]), | |
vector=to_proto_vector(embedding_record["embedding"], encoding), | |
) | |
def from_proto_vector_query_result( | |
vector_query_result: proto.VectorQueryResult, | |
) -> VectorQueryResult: | |
return VectorQueryResult( | |
id=vector_query_result.id, | |
seq_id=from_proto_seq_id(vector_query_result.seq_id), | |
distance=vector_query_result.distance, | |
embedding=from_proto_vector(vector_query_result.vector)[0], | |
) | |
def to_proto_seq_id(seq_id: SeqId) -> bytes: | |
return int_to_bytes(seq_id) | |
def from_proto_seq_id(seq_id: bytes) -> SeqId: | |
return bytes_to_int(seq_id) | |