chroma / chromadb /proto /convert.py
badalsahani's picture
feat: chroma initial deploy
287a0bc
import array
from uuid import UUID
from typing import Dict, Optional, Tuple, Union, cast
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.utils.messageid import bytes_to_int, int_to_bytes
from chromadb.types import (
Collection,
EmbeddingRecord,
Metadata,
Operation,
ScalarEncoding,
Segment,
SegmentScope,
SeqId,
SubmitEmbeddingRecord,
UpdateMetadata,
Vector,
VectorEmbeddingRecord,
VectorQueryResult,
)
# TODO: Add unit tests for this file, including handling of optional/unset fields.
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    """Serialize ``vector`` into a ``proto.Vector`` message.

    The elements are packed with ``array.array`` using type code ``"f"`` for
    ``ScalarEncoding.FLOAT32`` or ``"i"`` for ``ScalarEncoding.INT32``. The
    message carries the packed bytes, the element count (``dimension``), and
    the corresponding proto encoding.

    Raises:
        ValueError: If ``encoding`` is neither FLOAT32 nor INT32.
    """
    if encoding == ScalarEncoding.FLOAT32:
        type_code = "f"
    elif encoding == ScalarEncoding.INT32:
        type_code = "i"
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )

    packed = array.array(type_code, vector).tobytes()
    # Only the two encodings above reach this point, so pick the proto twin.
    if type_code == "f":
        proto_encoding = proto.ScalarEncoding.FLOAT32
    else:
        proto_encoding = proto.ScalarEncoding.INT32
    return proto.Vector(dimension=len(vector), vector=packed, encoding=proto_encoding)
def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a ``proto.Vector`` into a list of numbers and its encoding.

    The byte payload (``vector.vector``) is unpacked with ``array.array``
    using type code ``"f"`` for FLOAT32 or ``"i"`` for INT32. The message's
    ``dimension`` field is not read here; the length of the result comes
    from the payload itself.

    Returns:
        A ``(values, encoding)`` tuple where ``values`` is a plain list and
        ``encoding`` is the matching ``ScalarEncoding`` member.

    Raises:
        ValueError: If the message's encoding is neither FLOAT32 nor INT32.
    """
    wire_encoding = vector.encoding
    if wire_encoding == proto.ScalarEncoding.FLOAT32:
        type_code, out_encoding = "f", ScalarEncoding.FLOAT32
    elif wire_encoding == proto.ScalarEncoding.INT32:
        type_code, out_encoding = "i", ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {wire_encoding}, expected one of "
            f"{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )

    values: Union[array.array[float], array.array[int]] = array.array(type_code)
    values.frombytes(vector.vector)
    return values.tolist(), out_encoding
def from_proto_operation(operation: proto.Operation) -> Operation:
    """Translate a proto operation value into the native ``Operation`` enum.

    Raises:
        RuntimeError: If ``operation`` is not ADD, UPDATE, UPSERT, or DELETE.
    """
    known_operations = (
        (proto.Operation.ADD, Operation.ADD),
        (proto.Operation.UPDATE, Operation.UPDATE),
        (proto.Operation.UPSERT, Operation.UPSERT),
        (proto.Operation.DELETE, Operation.DELETE),
    )
    # Scan in declaration order and compare with ``==``.
    for wire_value, native_value in known_operations:
        if operation == wire_value:
            return native_value
    # TODO: full error
    raise RuntimeError(f"Unknown operation {operation}")
def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    """Convert a proto metadata message into a ``Metadata`` mapping.

    Delegates to ``_from_proto_metadata_handle_none`` with ``is_update=False``,
    so an entry with no string/int/float value set raises ``ValueError``.
    Returns ``None`` when the message has no entries.
    """
    converted = _from_proto_metadata_handle_none(metadata, False)
    return cast(Optional[Metadata], converted)
def from_proto_update_metadata(
    metadata: proto.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Convert a proto metadata message into an ``UpdateMetadata`` mapping.

    Delegates to ``_from_proto_metadata_handle_none`` with ``is_update=True``,
    so an entry with no string/int/float value set maps to ``None``.
    Returns ``None`` when the message has no entries.
    """
    converted = _from_proto_metadata_handle_none(metadata, True)
    return cast(Optional[UpdateMetadata], converted)
def _from_proto_metadata_handle_none(
    metadata: proto.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Convert the entries of a proto metadata message into a plain dict.

    Args:
        metadata: Message whose ``metadata`` map holds the values to convert.
        is_update: When true, an entry with no value field set becomes
            ``None`` in the result; when false such an entry is an error.

    Returns:
        ``None`` if the message's ``metadata`` map is empty, otherwise a dict
        keyed like the map with ``str``/``int``/``float`` (or ``None``) values.

    Raises:
        ValueError: If ``is_update`` is false and an entry has none of
            ``string_value``, ``int_value``, or ``float_value`` set.
    """
    entries = metadata.metadata
    if not entries:
        return None

    converted: Dict[str, Union[str, int, float, None]] = {}
    for name, wrapped in entries.items():
        # Probe the value fields in priority order; the first one set wins.
        for field_name in ("string_value", "int_value", "float_value"):
            if wrapped.HasField(field_name):
                converted[name] = getattr(wrapped, field_name)
                break
        else:
            # No value field is set on this entry.
            if not is_update:
                raise ValueError(f"Metadata key {name} value cannot be None")
            converted[name] = None
    return converted
def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata:
    """Build a ``proto.UpdateMetadata`` message from a metadata mapping.

    Each value is wrapped with ``to_proto_metadata_update_value``; keys are
    carried over unchanged.
    """
    wrapped_values = {}
    for key, value in metadata.items():
        wrapped_values[key] = to_proto_metadata_update_value(value)
    return proto.UpdateMetadata(metadata=wrapped_values)
def from_proto_submit(
    submit_embedding_record: proto.SubmitEmbeddingRecord, seq_id: SeqId
) -> EmbeddingRecord:
    """Convert a proto submit record into an ``EmbeddingRecord``.

    Args:
        submit_embedding_record: The proto message to convert.
        seq_id: Sequence id to attach to the resulting record; it is passed
            through unchanged.

    Returns:
        An ``EmbeddingRecord`` whose vector, metadata, operation, and
        collection id are decoded from the message.
    """
    embedding, encoding = from_proto_vector(submit_embedding_record.vector)
    metadata = from_proto_update_metadata(submit_embedding_record.metadata)
    operation = from_proto_operation(submit_embedding_record.operation)
    collection_id = UUID(hex=submit_embedding_record.collection_id)
    return EmbeddingRecord(
        id=submit_embedding_record.id,
        seq_id=seq_id,
        embedding=embedding,
        encoding=encoding,
        metadata=metadata,
        operation=operation,
        collection_id=collection_id,
    )
def from_proto_segment(segment: proto.Segment) -> Segment:
    """Convert a ``proto.Segment`` message into a ``Segment``.

    ``topic``, ``collection``, and ``metadata`` are each set to ``None`` when
    ``HasField`` reports the corresponding field as absent. The id and the
    collection are parsed from hex strings into ``UUID`` objects.
    """
    segment_id = UUID(hex=segment.id)
    scope = from_proto_segment_scope(segment.scope)

    topic = None
    if segment.HasField("topic"):
        topic = segment.topic

    collection_id = None
    if segment.HasField("collection"):
        collection_id = UUID(hex=segment.collection)

    metadata = None
    if segment.HasField("metadata"):
        metadata = from_proto_metadata(segment.metadata)

    return Segment(
        id=segment_id,
        type=segment.type,
        scope=scope,
        topic=topic,
        collection=collection_id,
        metadata=metadata,
    )
def to_proto_segment(segment: Segment) -> proto.Segment:
    """Build a ``proto.Segment`` message from a ``Segment`` mapping.

    The id (and the collection, when present) is written as its hex string.
    A ``None`` collection or metadata is passed to the message as ``None``.
    """
    segment_id = segment["id"].hex
    segment_type = segment["type"]
    scope = to_proto_segment_scope(segment["scope"])
    topic = segment["topic"]

    collection = segment["collection"]
    collection_hex = collection.hex if collection is not None else None

    metadata = segment["metadata"]
    proto_metadata = (
        to_proto_update_metadata(metadata) if metadata is not None else None
    )

    return proto.Segment(
        id=segment_id,
        type=segment_type,
        scope=scope,
        topic=topic,
        collection=collection_hex,
        metadata=proto_metadata,
    )
def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope:
    """Translate a proto segment scope into the native ``SegmentScope`` enum.

    Raises:
        RuntimeError: If ``segment_scope`` is neither VECTOR nor METADATA.
    """
    if segment_scope == proto.SegmentScope.VECTOR:
        return SegmentScope.VECTOR
    if segment_scope == proto.SegmentScope.METADATA:
        return SegmentScope.METADATA
    raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope:
    """Translate a native ``SegmentScope`` into its proto counterpart.

    Raises:
        RuntimeError: If ``segment_scope`` is neither VECTOR nor METADATA.
    """
    if segment_scope == SegmentScope.VECTOR:
        return proto.SegmentScope.VECTOR
    if segment_scope == SegmentScope.METADATA:
        return proto.SegmentScope.METADATA
    raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_metadata_update_value(
    value: Union[str, int, float, None]
) -> proto.UpdateMetadataValue:
    """Wrap a single metadata value in a ``proto.UpdateMetadataValue``.

    ``str``, ``int``, and ``float`` populate ``string_value``, ``int_value``,
    and ``float_value`` respectively; ``None`` yields a message with no value
    field populated.

    Raises:
        ValueError: If ``value`` is not a str, int, float, or None.
    """
    if isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    # Note: bool is a subclass of int, so True/False take this branch.
    if isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    if isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    if value is None:
        return proto.UpdateMetadataValue()
    raise ValueError(
        f"Unknown metadata value type {type(value)}, expected one of str, int, "
        "float, or None"
    )
def from_proto_collection(collection: proto.Collection) -> Collection:
    """Convert a ``proto.Collection`` message into a ``Collection``.

    ``metadata`` is ``None`` when ``HasField`` reports it as absent.
    ``dimension`` is ``None`` when the field is absent or its value is falsy
    (for example ``0``). The id is parsed from a hex string into a ``UUID``.
    """
    collection_id = UUID(hex=collection.id)

    metadata = None
    if collection.HasField("metadata"):
        metadata = from_proto_metadata(collection.metadata)

    dimension = None
    if collection.HasField("dimension") and collection.dimension:
        dimension = collection.dimension

    return Collection(
        id=collection_id,
        name=collection.name,
        topic=collection.topic,
        metadata=metadata,
        dimension=dimension,
        database=collection.database,
        tenant=collection.tenant,
    )
def to_proto_collection(collection: Collection) -> proto.Collection:
    """Build a ``proto.Collection`` message from a ``Collection`` mapping.

    The id is written as its hex string. ``None`` metadata is passed to the
    message as ``None``; otherwise it is converted with
    ``to_proto_update_metadata``. The remaining fields are copied as-is.
    """
    collection_id = collection["id"].hex
    name = collection["name"]
    topic = collection["topic"]

    metadata = collection["metadata"]
    proto_metadata = (
        to_proto_update_metadata(metadata) if metadata is not None else None
    )

    return proto.Collection(
        id=collection_id,
        name=name,
        topic=topic,
        metadata=proto_metadata,
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
    )
def to_proto_operation(operation: Operation) -> proto.Operation:
    """Translate a native ``Operation`` into its proto counterpart.

    Raises:
        ValueError: If ``operation`` is not ADD, UPDATE, UPSERT, or DELETE.
    """
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        # The message lists every accepted member; it previously repeated
        # UPDATE and omitted UPSERT.
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )
def to_proto_submit(
    submit_record: SubmitEmbeddingRecord,
) -> proto.SubmitEmbeddingRecord:
    """Build a ``proto.SubmitEmbeddingRecord`` from a submit record mapping.

    The vector is included only when both ``embedding`` and ``encoding`` are
    not ``None``; the metadata is included only when it is not ``None``.
    Otherwise the respective message field is passed as ``None``. The
    collection id is written as its hex string.
    """
    embedding = submit_record["embedding"]
    # ``encoding`` is only looked up when an embedding is present.
    has_vector = embedding is not None and submit_record["encoding"] is not None
    vector = (
        to_proto_vector(embedding, submit_record["encoding"]) if has_vector else None
    )

    raw_metadata = submit_record["metadata"]
    metadata = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return proto.SubmitEmbeddingRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=metadata,
        operation=to_proto_operation(submit_record["operation"]),
        collection_id=submit_record["collection_id"].hex,
    )
def from_proto_vector_embedding_record(
    embedding_record: proto.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Convert a ``proto.VectorEmbeddingRecord`` into a ``VectorEmbeddingRecord``.

    The sequence id is decoded with ``from_proto_seq_id`` and the vector with
    ``from_proto_vector``; the vector's encoding is discarded.
    """
    seq_id = from_proto_seq_id(embedding_record.seq_id)
    embedding, _ = from_proto_vector(embedding_record.vector)
    return VectorEmbeddingRecord(
        id=embedding_record.id,
        seq_id=seq_id,
        embedding=embedding,
    )
def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> proto.VectorEmbeddingRecord:
    """Build a ``proto.VectorEmbeddingRecord`` from a ``VectorEmbeddingRecord``.

    Args:
        embedding_record: Mapping providing ``id``, ``seq_id``, and ``embedding``.
        encoding: Scalar encoding used to serialize the embedding.
    """
    record_id = embedding_record["id"]
    proto_seq_id = to_proto_seq_id(embedding_record["seq_id"])
    proto_vector = to_proto_vector(embedding_record["embedding"], encoding)
    return proto.VectorEmbeddingRecord(
        id=record_id,
        seq_id=proto_seq_id,
        vector=proto_vector,
    )
def from_proto_vector_query_result(
    vector_query_result: proto.VectorQueryResult,
) -> VectorQueryResult:
    """Convert a ``proto.VectorQueryResult`` into a ``VectorQueryResult``.

    The sequence id is decoded with ``from_proto_seq_id`` and the vector with
    ``from_proto_vector`` (its encoding is discarded); ``id`` and ``distance``
    are copied from the message.
    """
    seq_id = from_proto_seq_id(vector_query_result.seq_id)
    embedding, _ = from_proto_vector(vector_query_result.vector)
    return VectorQueryResult(
        id=vector_query_result.id,
        seq_id=seq_id,
        distance=vector_query_result.distance,
        embedding=embedding,
    )
def to_proto_seq_id(seq_id: SeqId) -> bytes:
    """Encode a sequence id as bytes using ``int_to_bytes``."""
    encoded = int_to_bytes(seq_id)
    return encoded
def from_proto_seq_id(seq_id: bytes) -> SeqId:
    """Decode a byte-encoded sequence id using ``bytes_to_int``."""
    decoded = bytes_to_int(seq_id)
    return decoded