Update
I've updated this model to be compatible with Fastembed.
I removed the `sentence_embedding` output and quantized the main `token_embeddings` output instead, so the model now outputs a 768-dimensional multivector (one vector per token).
To use the output, apply CLS pooling with normalization disabled.
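A minimal sketch of what "CLS pooling with normalization disabled" does to this output. It assumes the token-level output has shape `(batch, seq_len, 768)` with dtype uint8 and that the CLS token sits at position 0. The random array is a hypothetical stand-in for real model output, and `cls_pool_uint8` is an illustrative helper, not part of Fastembed.

```python
import numpy as np


def cls_pool_uint8(token_embeddings: np.ndarray) -> np.ndarray:
    """CLS pooling: keep only the first token's vector for each sequence.

    Assumes shape (batch, seq_len, dim) and the CLS token at position 0.
    No normalization is applied, so the result keeps its uint8 dtype.
    """
    if token_embeddings.ndim != 3:
        raise ValueError("expected shape (batch, seq_len, dim)")
    return token_embeddings[:, 0, :]


# Hypothetical stand-in for the model output: 2 texts, 5 tokens each.
rng = np.random.default_rng(0)
fake_output = rng.integers(0, 256, size=(2, 5, 768), dtype=np.uint8)

embeddings = cls_pool_uint8(fake_output)
print(embeddings.shape, embeddings.dtype)  # (2, 768) uint8
```

Skipping normalization matters here: L2-normalizing would turn the vectors into floats, which could no longer be stored as uint8 without re-quantizing.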
snowflake2_m_uint8
This is a slightly modified version of the uint8-quantized ONNX model from https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0
I added a linear quantization node before the `token_embeddings` output so that the model directly outputs a 768-dimensional uint8 multivector.
This is compatible with Qdrant's uint8 vector datatype for collections.
I took the liberty of removing the `sentence_embedding` output, since I would have had to re-create it. I can add it back if anybody wants it.
Quantization method
Linear quantization of the float range -7 to 7 to uint8.
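A minimal sketch of linear quantization over [-7, 7], assuming the range is mapped onto the full uint8 range 0 to 255 and that out-of-range values are clipped. The exact scale, zero point, and rounding mode used by the node in the ONNX graph are not stated here, so `quantize`, `dequantize`, and the constants are illustrative assumptions.

```python
LOW, HIGH = -7.0, 7.0  # float range from the text
QMAX = 255             # largest uint8 value (assumed full-range mapping)
STEP = (HIGH - LOW) / QMAX  # width of one quantization level, about 0.0549


def quantize(x: float) -> int:
    """Map a float in [LOW, HIGH] linearly to an int in [0, 255], clipping outliers."""
    q = round((x - LOW) * QMAX / (HIGH - LOW))
    return max(0, min(QMAX, q))


def dequantize(q: int) -> float:
    """Approximate inverse: map a uint8 level back to a float in [LOW, HIGH]."""
    return LOW + q * STEP


for value in (-7.0, -0.5, 0.0, 3.2, 9.0):
    q = quantize(value)
    print(f"{value:+.2f} -> {q:3d} -> {dequantize(q):+.4f}")
```

For in-range inputs the round-trip error is at most half a step (about 0.027). Anything outside [-7, 7], such as the 9.0 above, saturates at the ends of the range.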
Here's what the graph of the original output looks like:
Here's what the new graph in this model looks like:
Benchmark
I used beir-qdrant with the scifact dataset.
quantized output (this model):
ndcg: {'NDCG@1': 0.59333, 'NDCG@3': 0.64619, 'NDCG@5': 0.6687, 'NDCG@10': 0.69228, 'NDCG@100': 0.72204, 'NDCG@1000': 0.72747}
recall: {'Recall@1': 0.56094, 'Recall@3': 0.68394, 'Recall@5': 0.73983, 'Recall@10': 0.80689, 'Recall@100': 0.94833, 'Recall@1000': 0.99333}
precision: {'P@1': 0.59333, 'P@3': 0.25, 'P@5': 0.16467, 'P@10': 0.09167, 'P@100': 0.01077, 'P@1000': 0.00112}
unquantized output (model_uint8.onnx):
ndcg: {'NDCG@1': 0.59333, 'NDCG@3': 0.65417, 'NDCG@5': 0.6741, 'NDCG@10': 0.69675, 'NDCG@100': 0.7242, 'NDCG@1000': 0.7305}
recall: {'Recall@1': 0.56094, 'Recall@3': 0.69728, 'Recall@5': 0.74817, 'Recall@10': 0.81356, 'Recall@100': 0.945, 'Recall@1000': 0.99667}
precision: {'P@1': 0.59333, 'P@3': 0.25444, 'P@5': 0.16667, 'P@10': 0.09233, 'P@100': 0.01073, 'P@1000': 0.00113}
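To compare the two runs directly, a small helper can compute the per-metric difference. The NDCG numbers are copied from the benchmark output above; `deltas` and the variable names are illustrative only.

```python
# NDCG scores from the scifact benchmark output above.
quantized = {"NDCG@1": 0.59333, "NDCG@3": 0.64619, "NDCG@5": 0.6687,
             "NDCG@10": 0.69228, "NDCG@100": 0.72204, "NDCG@1000": 0.72747}
unquantized = {"NDCG@1": 0.59333, "NDCG@3": 0.65417, "NDCG@5": 0.6741,
               "NDCG@10": 0.69675, "NDCG@100": 0.7242, "NDCG@1000": 0.7305}


def deltas(a: dict, b: dict) -> dict:
    """Return {metric: (a - b, relative change in percent of b)} for each metric in a."""
    out = {}
    for metric, value in a.items():
        diff = value - b[metric]
        out[metric] = (diff, 100 * diff / b[metric])
    return out


result = deltas(quantized, unquantized)
for metric, (diff, pct) in result.items():
    print(f"{metric:>10}: {diff:+.5f} ({pct:+.2f}%)")
```

At NDCG@10 the quantized output scores about 0.0045 lower, roughly 0.6% relative. The same helper works on the recall and precision dictionaries.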
Example inference/benchmark code and how to use the model with Fastembed
After installing beir-qdrant, make sure to upgrade fastembed:
```python
# pip install qdrant_client beir-qdrant
# pip install -U fastembed
from fastembed import TextEmbedding
from fastembed.common.model_description import PoolingType, ModelSource
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from qdrant_client import QdrantClient
from qdrant_client.models import Datatype
from beir_qdrant.retrieval.models.fastembed import DenseFastEmbedModelAdapter
from beir_qdrant.retrieval.search.dense import DenseQdrantSearch

# Register the model with Fastembed: CLS pooling, no normalization, 768 dimensions.
TextEmbedding.add_custom_model(
    model="electroglyph/snowflake2_m_uint8",
    pooling=PoolingType.CLS,
    normalization=False,
    sources=ModelSource(hf="electroglyph/snowflake2_m_uint8"),
    dim=768,
    model_file="snowflake2_m_uint8.onnx",
)

# Download and load the test split of the BEIR scifact dataset.
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
data_path = util.download_and_unzip(url, "datasets")
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

# Connect to a running Qdrant instance; vectors are stored in a uint8 collection.
qdrant_client = QdrantClient("http://localhost:6333")
model = DenseQdrantSearch(
    qdrant_client,
    model=DenseFastEmbedModelAdapter(
        model_name="electroglyph/snowflake2_m_uint8"
    ),
    collection_name="scifact-uint8",
    initialize=True,
    datatype=Datatype.UINT8,
)

# Retrieve results for every query and score them against the relevance judgments.
retriever = EvaluateRetrieval(model)
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
print(f"ndcg: {ndcg}\nrecall: {recall}\nprecision: {precision}")
```