Upload 15 files

- 1_Pooling/config.json +10 -0
- 2_Dense/config.json +6 -0
- 2_Dense/model.safetensors +3 -0
- README.md +191 -3
- config.json +25 -0
- config_sentence_transformers.json +14 -0
- logo.webp +0 -0
- model.safetensors +3 -0
- modules.json +20 -0
- sentence_bert_config.json +7 -0
- special_tokens_map.json +37 -0
- tokenizer.json +0 -0
- tokenizer_config.json +65 -0
- transformers_example_mt.ipynb +167 -0
- vocab.txt +0 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "word_embedding_dimension": 384,
+  "pooling_mode_cls_token": false,
+  "pooling_mode_mean_tokens": true,
+  "pooling_mode_max_tokens": false,
+  "pooling_mode_mean_sqrt_len_tokens": false,
+  "pooling_mode_weightedmean_tokens": false,
+  "pooling_mode_lasttoken": false,
+  "include_prompt": true
+}
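
For reference, this configuration selects attention-mask-aware mean pooling (`pooling_mode_mean_tokens: true`) over the backbone's 384-dim token embeddings. A minimal sketch of the computation this module performs (the function name is illustrative, not the library's internals):

```python
import torch

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then average over the remaining tokens.
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # (batch, 384)
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # (batch, 1)
    return summed / counts
```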
2_Dense/config.json
ADDED
@@ -0,0 +1,6 @@
+{
+  "in_features": 384,
+  "out_features": 1024,
+  "bias": true,
+  "activation_function": "torch.nn.modules.linear.Identity"
+}
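
This Dense head is a plain linear projection from the backbone's 384-dim pooled output to the model's 1024-dim embedding space; the `Identity` activation means no nonlinearity is applied. A minimal PyTorch equivalent (weights here are randomly initialized, not the trained ones from `2_Dense/model.safetensors`):

```python
import torch

# 384 -> 1024 projection with bias and no activation, as configured above.
dense = torch.nn.Linear(in_features=384, out_features=1024, bias=True)
pooled = torch.rand(2, 384)   # stand-in for pooled sentence embeddings
embeddings = dense(pooled)    # shape: (2, 1024)
```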
2_Dense/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ded80c388c2553180f8ca297277ffbc9334dfbe3089dd9444b1fc50535f060f2
+size 1577120
README.md
CHANGED
@@ -1,3 +1,191 @@
----
-license: apache-2.0
-
+---
+license: apache-2.0
+base_model: microsoft/MiniLM-L6-v2
+tags:
+- transformers
+- sentence-transformers
+- sentence-similarity
+- feature-extraction
+- text-embeddings-inference
+- information-retrieval
+- knowledge-distillation
+language:
+- en
+---
+<div style="display: flex; justify-content: center;">
+  <div style="display: flex; align-items: center; gap: 10px;">
+    <img src="logo.webp" alt="MongoDB Logo" style="height: 36px; width: auto; border-radius: 4px;">
+    <span style="font-size: 32px; font-weight: bold">MongoDB/mdbr-leaf-mt</span>
+  </div>
+</div>
+
+# Introduction
+
+`mdbr-leaf-mt` is a compact, high-performance text embedding model designed for classification, clustering, semantic sentence similarity, and summarization tasks.
+
+To enable even greater efficiency, `mdbr-leaf-mt` supports [flexible asymmetric architectures](#asymmetric-retrieval-setup) and is robust to [vector quantization](#vector-quantization) and [MRL truncation](#mrl-truncation).
+
+If you are looking to perform semantic search / information retrieval (e.g., for RAG), please check out our [`mdbr-leaf-ir`](https://huggingface.co/MongoDB/mdbr-leaf-ir) model, which is specifically trained for those tasks.
+
+> [!NOTE]
+> This model was developed by the ML team of MongoDB Research. At the time of writing, it is not used in any of MongoDB's commercial products or services.
+
+# Technical Report
+
+A technical report detailing our proposed `LEAF` training procedure is [available here (TBD)](http://FILL_HERE_ARXIV_LINK).
+
+# Highlights
+
+* **State-of-the-Art Performance**: `mdbr-leaf-mt` achieves new state-of-the-art results for compact embedding models, ranking <span style="color:red">#TBD</span> on the [public MTEB v2 (Eng) benchmark leaderboard](https://huggingface.co/spaces/mteb/leaderboard) for models with <30M parameters, with an average score of <span style="color:red">[TBD HERE]</span>.
+* **Flexible Architecture Support**: `mdbr-leaf-mt` supports asymmetric retrieval architectures, enabling even better retrieval quality. [See below](#asymmetric-retrieval-setup) for more information.
+* **MRL and Quantization Support**: embedding vectors generated by `mdbr-leaf-mt` compress well when truncated (MRL) and/or can be stored using more efficient types like `int8` and `binary`. [See below](#mrl-truncation) for more information.
+
+# Quickstart
+
+## Sentence Transformers
+
+```python
+from sentence_transformers import SentenceTransformer
+
+# Load the model
+model = SentenceTransformer("MongoDB/mdbr-leaf-mt")
+
+# Example queries and documents
+queries = [
+    "What is machine learning?",
+    "How does neural network training work?"
+]
+
+documents = [
+    "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.",
+    "Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors."
+]
+
+# Encode queries and documents
+query_embeddings = model.encode(queries, prompt_name="query")
+document_embeddings = model.encode(documents)
+
+# Compute similarity scores
+scores = model.similarity(query_embeddings, document_embeddings)
+
+# Print results
+for i, query in enumerate(queries):
+    print(f"Query: {query}")
+    for j, doc in enumerate(documents):
+        print(f"  Similarity: {scores[i, j]:.4f} | Document {j}: {doc[:80]}...")
+
+# Query: What is machine learning?
+#   Similarity: 0.9063 | Document 0: Machine learning is a subset of ...
+#   Similarity: 0.7287 | Document 1: Neural networks are trained ...
+#
+# Query: How does neural network training work?
+#   Similarity: 0.6725 | Document 0: Machine learning is a subset of ...
+#   Similarity: 0.8287 | Document 1: Neural networks are trained ...
+```
+
+## Transformers Usage
+
+See [here](https://huggingface.co/MongoDB/mdbr-leaf-mt/blob/main/transformers_example_mt.ipynb).
+
+## Asymmetric Retrieval Setup
+
+`mdbr-leaf-mt` is *aligned* to [`mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1), the model it was distilled from, which makes the asymmetric setup below possible:
+
+```python
+# Use mdbr-leaf-mt for query encoding (real-time, low latency)
+query_model = SentenceTransformer("MongoDB/mdbr-leaf-mt")
+query_embeddings = query_model.encode(queries, prompt_name="query")
+
+# Use a larger model for document encoding (one-time, at index time)
+doc_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+document_embeddings = doc_model.encode(documents)
+
+# Compute similarities
+scores = query_model.similarity(query_embeddings, document_embeddings)
+```
+
+Retrieval results from the asymmetric mode are usually superior to those from the [standard mode above](#sentence-transformers).
+
+## MRL Truncation
+
+Embeddings have been trained via [MRL](https://arxiv.org/abs/2205.13147) and can be truncated for more efficient storage:
+
+```python
+from torch.nn import functional as F
+
+query_embeds = model.encode(queries, prompt_name="query", convert_to_tensor=True)
+doc_embeds = model.encode(documents, convert_to_tensor=True)
+
+# Truncate and normalize according to MRL
+query_embeds = F.normalize(query_embeds[:, :256], dim=-1)
+doc_embeds = F.normalize(doc_embeds[:, :256], dim=-1)
+
+similarities = model.similarity(query_embeds, doc_embeds)
+
+print('After MRL:')
+print(f"* Embeddings dimension: {query_embeds.shape[1]}")
+print(f"* Similarities:\n\t{similarities}")
+
+# After MRL:
+# * Embeddings dimension: 256
+# * Similarities:
+#     tensor([[0.9164, 0.7219],
+#             [0.6682, 0.8393]], device='cuda:0')
+```
+
+## Vector Quantization
+
+Vector quantization, for example to `int8` or `binary`, can be performed as follows:
+
+**Note**: for vector quantization to types other than binary, we suggest performing a calibration to determine the optimal ranges, [see here](https://sbert.net/examples/sentence_transformer/applications/embedding-quantization/README.html#scalar-int8-quantization). Good initial values are -1.0 and +1.0.
+
+```python
+import torch
+from sentence_transformers.quantization import quantize_embeddings
+
+query_embeds = model.encode(queries, prompt_name="query")
+doc_embeds = model.encode(documents)
+
+# Quantize embeddings to int8 using -1.0 and +1.0 as calibration ranges
+ranges = torch.tensor([[-1.0], [+1.0]]).expand(2, query_embeds.shape[1]).cpu().numpy()
+query_embeds = quantize_embeddings(query_embeds, "int8", ranges=ranges)
+doc_embeds = quantize_embeddings(doc_embeds, "int8", ranges=ranges)
+
+# Calculate similarities; cast to int64 to avoid under/overflow
+similarities = query_embeds.astype(int) @ doc_embeds.astype(int).T
+
+print('After quantization:')
+print(f"* Embeddings type: {query_embeds.dtype}")
+print(f"* Similarities:\n{similarities}")
+
+# After quantization:
+# * Embeddings type: int8
+# * Similarities:
+# [[2202032 1422868]
+#  [1421197 1845580]]
+```
+
+# Evaluation
+
+The checkpoint used to produce the scores presented in the paper [is here](https://huggingface.co/MongoDB/mdbr-leaf-mt/commit/ea98995e96beac21b820aa8ad9afaa6fd29b243d).
+
+# Citation
+
+If you use this model in your work, please cite:
+
+```bibtex
+@article{mdb_leaf,
+    title         = {LEAF: Lightweight Embedding Alignment Knowledge Distillation Framework},
+    author        = {Robin Vujanic and Thomas Rueckstiess},
+    year          = {2025},
+    eprint        = {TBD},
+    archiveprefix = {arXiv},
+    primaryclass  = {FILL HERE},
+    url           = {FILL HERE}
+}
+```
+
+# License
+
+This model is released under the Apache 2.0 license.
+
+# Contact
+
+For questions or issues, please open an issue or pull request. You can also contact the MongoDB ML Research team at [email protected].
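
The README's quantization example covers `int8`; the same sentence-transformers helper also supports the `binary` precision mentioned in the highlights. A short sketch, assuming the `model`, `queries`, and `documents` from the Quickstart above:

```python
from sentence_transformers.quantization import quantize_embeddings

query_embeds = model.encode(queries, prompt_name="query")
doc_embeds = model.encode(documents)

# "binary" keeps only the sign of each dimension and packs 8 dims per byte,
# so a 1024-dim float32 vector shrinks to 128 bytes.
binary_queries = quantize_embeddings(query_embeds, "binary")
binary_docs = quantize_embeddings(doc_embeds, "binary")
print(binary_queries.shape, binary_queries.dtype)  # (2, 128) int8
```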
config.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "architectures": [
+    "BertModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "classifier_dropout": null,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 384,
+  "initializer_range": 0.02,
+  "intermediate_size": 1536,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 6,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "torch_dtype": "float32",
+  "transformers_version": "4.52.4",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
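
These hyperparameters describe the distilled backbone: a 6-layer BERT encoder with hidden size 384. As a rough sanity check against the README's <30M-parameter claim, `model.safetensors` below holds 90,272,656 bytes of `float32` weights, i.e. roughly 90272656 / 4 ≈ 22.6M parameters. A small sketch to inspect the config programmatically:

```python
from transformers import AutoConfig

# Inspect the backbone configuration declared in config.json.
cfg = AutoConfig.from_pretrained("MongoDB/mdbr-leaf-mt")
print(cfg.num_hidden_layers, cfg.hidden_size)  # 6, 384
```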
config_sentence_transformers.json
ADDED
@@ -0,0 +1,14 @@
+{
+  "model_type": "SentenceTransformer",
+  "__version__": {
+    "sentence_transformers": "5.1.0",
+    "transformers": "4.52.4",
+    "pytorch": "2.6.0+cu126"
+  },
+  "prompts": {
+    "query": "Represent this sentence for searching relevant passages: ",
+    "document": ""
+  },
+  "default_prompt_name": null,
+  "similarity_fn_name": "cosine"
+}
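
The `prompts` entry is what makes `prompt_name="query"` in the README work: Sentence Transformers prepends the registered prefix to each input before encoding, while documents are encoded with no prefix. A sketch of the equivalence (both calls below should yield the same embedding):

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("MongoDB/mdbr-leaf-mt")

# prompt_name="query" prepends the prefix registered in this config file.
a = model.encode(["What is machine learning?"], prompt_name="query")
b = model.encode(["Represent this sentence for searching relevant passages: What is machine learning?"])
```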
logo.webp
ADDED
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f80e7c12db001fca76f7b85132fa616cb8b8f9bb11c83ecac8c1779420b65d0
+size 90272656
modules.json
ADDED
@@ -0,0 +1,20 @@
+[
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "sentence_transformers.models.Transformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Pooling",
+    "type": "sentence_transformers.models.Pooling"
+  },
+  {
+    "idx": 2,
+    "name": "2",
+    "path": "2_Dense",
+    "type": "sentence_transformers.models.Dense"
+  }
+]
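
`modules.json` declares the inference pipeline: Transformer backbone (idx 0) → mean Pooling (idx 1) → 384→1024 Dense projection (idx 2). A hedged sketch of building the same chain by hand; note the Dense weights here would be freshly initialized, whereas loading the repo directly restores the trained projection:

```python
import torch.nn as nn
from sentence_transformers import SentenceTransformer, models

# Recreate the three-stage module chain declared in modules.json.
transformer = models.Transformer("MongoDB/mdbr-leaf-mt", max_seq_length=512)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
dense = models.Dense(in_features=384, out_features=1024, bias=True,
                     activation_function=nn.Identity())
model = SentenceTransformer(modules=[transformer, pooling, dense])
```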
sentence_bert_config.json
ADDED
@@ -0,0 +1,7 @@
+{
+  "max_seq_length": 512,
+  "do_lower_case": false,
+  "model_args": {
+    "add_pooling_layer": false
+  }
+}
special_tokens_map.json
ADDED
@@ -0,0 +1,37 @@
+{
+  "cls_token": {
+    "content": "[CLS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "[MASK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "[PAD]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "[SEP]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "[UNK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,65 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "[PAD]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "100": {
+      "content": "[UNK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "101": {
+      "content": "[CLS]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "102": {
+      "content": "[SEP]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "103": {
+      "content": "[MASK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "clean_up_tokenization_spaces": false,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": true,
+  "extra_special_tokens": {},
+  "mask_token": "[MASK]",
+  "max_length": 128,
+  "model_max_length": 512,
+  "never_split": null,
+  "pad_to_multiple_of": null,
+  "pad_token": "[PAD]",
+  "pad_token_type_id": 0,
+  "padding_side": "right",
+  "sep_token": "[SEP]",
+  "stride": 0,
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "truncation_side": "right",
+  "truncation_strategy": "longest_first",
+  "unk_token": "[UNK]"
+}
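
For reference, this is a standard lowercasing BERT WordPiece tokenizer that wraps inputs in `[CLS]`/`[SEP]`. A quick illustrative check (the expected output follows from `do_lower_case: true` and the special-token mappings above):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("MongoDB/mdbr-leaf-mt")
ids = tok("What is Machine Learning?")["input_ids"]
print(tok.convert_ids_to_tokens(ids))
# expected: ['[CLS]', 'what', 'is', 'machine', 'learning', '?', '[SEP]']
```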
transformers_example_mt.ipynb
ADDED
@@ -0,0 +1,167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "2a12a2b3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from safetensors import safe_open\n",
+    "import torch\n",
+    "from torch.nn import functional as F\n",
+    "from transformers import AutoModel, AutoTokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "148ce181",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# First clone the model locally\n",
+    "!git clone https://huggingface.co/MongoDB/mdbr-leaf-mt"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "ba9ec6c7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Some weights of BertModel were not initialized from the model checkpoint at mdbr-leaf-mt and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n",
+      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Then load it\n",
+    "MODEL = \"mdbr-leaf-mt\"\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
+    "model = AutoModel.from_pretrained(MODEL)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ebaf1a76",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tensors = {}\n",
+    "with safe_open(MODEL + \"/2_Dense/model.safetensors\", framework=\"pt\") as f:\n",
+    "    for k in f.keys():\n",
+    "        tensors[k] = f.get_tensor(k)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "03ffcd9c",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Similarities:\n",
+      "tensor([[0.9063, 0.7287],\n",
+      "        [0.6725, 0.8287]])\n"
+     ]
+    }
+   ],
+   "source": [
+    "if 'linear.bias' in tensors:\n",
+    "    W_out = torch.nn.Linear(in_features=384, out_features=1024, bias=True)\n",
+    "    W_out.load_state_dict({\n",
+    "        \"weight\": tensors[\"linear.weight\"],\n",
+    "        \"bias\": tensors[\"linear.bias\"]\n",
+    "    })\n",
+    "else:\n",
+    "    W_out = torch.nn.Linear(in_features=384, out_features=1024, bias=False)\n",
+    "    W_out.load_state_dict({\n",
+    "        \"weight\": tensors[\"linear.weight\"]\n",
+    "    })\n",
+    "\n",
+    "_ = model.eval()\n",
+    "_ = W_out.eval()\n",
+    "\n",
+    "# Example queries and documents\n",
+    "queries = [\n",
+    "    \"What is machine learning?\",\n",
+    "    \"How does neural network training work?\"\n",
+    "]\n",
+    "\n",
+    "documents = [\n",
+    "    \"Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\",\n",
+    "    \"Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors.\"\n",
+    "]\n",
+    "\n",
+    "# Tokenize\n",
+    "QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '\n",
+    "queries_with_prefix = [QUERY_PREFIX + query for query in queries]\n",
+    "\n",
+    "query_tokens = tokenizer(queries_with_prefix, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
+    "document_tokens = tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
+    "\n",
+    "# Perform inference\n",
+    "with torch.inference_mode():\n",
+    "    y_queries = model(**query_tokens).last_hidden_state\n",
+    "    y_docs = model(**document_tokens).last_hidden_state\n",
+    "\n",
+    "    # perform mean pooling over non-padding tokens\n",
+    "    y_queries = y_queries * query_tokens.attention_mask.unsqueeze(-1)\n",
+    "    y_queries_pooled = y_queries.sum(dim=1) / query_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
+    "\n",
+    "    y_docs = y_docs * document_tokens.attention_mask.unsqueeze(-1)\n",
+    "    y_docs_pooled = y_docs.sum(dim=1) / document_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
+    "\n",
+    "    # map to the 1024-dim output space via the Dense projection\n",
+    "    query_embeddings = W_out(y_queries_pooled)\n",
+    "    document_embeddings = W_out(y_docs_pooled)\n",
+    "\n",
+    "similarities = F.cosine_similarity(query_embeddings.unsqueeze(0), document_embeddings.unsqueeze(1), dim=-1).T\n",
+    "print(f\"Similarities:\\n{similarities}\")\n",
+    "\n",
+    "# Similarities:\n",
+    "# tensor([[0.9063, 0.7287],\n",
+    "#         [0.6725, 0.8287]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a2b0244",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "alexis",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
vocab.txt
ADDED
The diff for this file is too large to render.