rvo committed
Commit c342f94 · verified · 1 Parent(s): 943b9dc

Upload 15 files
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 384,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": true,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
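This pooling config enables mean-token pooling only (no CLS, max, or last-token pooling). As a rough illustration of what that setting computes (a sketch, not part of the uploaded files):

```python
import torch

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean over token embeddings, matching pooling_mode_mean_tokens=true."""
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # ignore padding tokens
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # number of real tokens
    return summed / counts                                           # (batch, 384)
```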
2_Dense/config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "in_features": 384,
+   "out_features": 1024,
+   "bias": true,
+   "activation_function": "torch.nn.modules.linear.Identity"
+ }
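The Dense module configured above is a plain linear projection from the 384-dimensional encoder output to the final 1024-dimensional embedding, with an identity activation. A minimal equivalent, assuming the repo is cloned locally and using the weight names found in `2_Dense/model.safetensors` (`linear.weight`, `linear.bias`):

```python
import torch
from safetensors.torch import load_file

# Assumes 2_Dense/model.safetensors is available on the local path.
state = load_file("2_Dense/model.safetensors")
dense = torch.nn.Linear(in_features=384, out_features=1024, bias=True)  # identity activation
dense.load_state_dict({"weight": state["linear.weight"], "bias": state["linear.bias"]})
```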
2_Dense/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ded80c388c2553180f8ca297277ffbc9334dfbe3089dd9444b1fc50535f060f2
+ size 1577120
README.md CHANGED
@@ -1,3 +1,191 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ base_model: microsoft/MiniLM-L6-v2
+ tags:
+ - transformers
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - text-embeddings-inference
+ - information-retrieval
+ - knowledge-distillation
+ language:
+ - en
+ ---
+ <div style="display: flex; justify-content: center;">
+   <div style="display: flex; align-items: center; gap: 10px;">
+     <img src="logo.webp" alt="MongoDB Logo" style="height: 36px; width: auto; border-radius: 4px;">
+     <span style="font-size: 32px; font-weight: bold">MongoDB/mdbr-leaf-mt</span>
+   </div>
+ </div>
+
+ # Introduction
+
+ `mdbr-leaf-mt` is a compact, high-performance text embedding model designed for classification, clustering, semantic sentence similarity, and summarization tasks.
+
+ To enable even greater efficiency, `mdbr-leaf-mt` supports [flexible asymmetric architectures](#asymmetric-retrieval-setup) and is robust to [vector quantization](#vector-quantization) and [MRL truncation](#mrl).
+
+ If you are looking to perform semantic search / information retrieval (e.g., for RAG), please check out our [`mdbr-leaf-ir`](https://huggingface.co/MongoDB/mdbr-leaf-ir) model, which is specifically trained for these tasks.
+
+ > [!Note]
+ > **Note**: this model has been developed by the ML team of MongoDB Research. At the time of writing, it is not used in any of MongoDB's commercial product or service offerings.
+
+ # Technical Report
+
+ A technical report detailing our proposed `LEAF` training procedure is [available here (TBD)](http://FILL_HERE_ARXIV_LINK).
+
+ # Highlights
+
+ * **State-of-the-Art Performance**: `mdbr-leaf-mt` achieves new state-of-the-art results for compact embedding models, ranking <span style="color:red">#TBD</span> on the [public MTEB v2 (Eng) benchmark leaderboard](https://huggingface.co/spaces/mteb/leaderboard) for models under 30M parameters, with an average score of <span style="color:red">[TBD HERE]</span>.
+ * **Flexible Architecture Support**: `mdbr-leaf-mt` supports asymmetric retrieval architectures, enabling even better retrieval quality. [See below](#asymmetric-retrieval-setup) for more information.
+ * **MRL and Quantization Support**: embedding vectors generated by `mdbr-leaf-mt` compress well when truncated (MRL) and/or can be stored using more efficient types like `int8` and `binary`. [See below](#mrl) for more information.
+
+ # Quickstart
+
+ ## Sentence Transformers
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Load the model
+ model = SentenceTransformer("MongoDB/mdbr-leaf-mt")
+
+ # Example queries and documents
+ queries = [
+     "What is machine learning?",
+     "How does neural network training work?"
+ ]
+
+ documents = [
+     "Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.",
+     "Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors."
+ ]
+
+ # Encode queries and documents
+ query_embeddings = model.encode(queries, prompt_name="query")
+ document_embeddings = model.encode(documents)
+
+ # Compute similarity scores
+ scores = model.similarity(query_embeddings, document_embeddings)
+
+ # Print results
+ for i, query in enumerate(queries):
+     print(f"Query: {query}")
+     for j, doc in enumerate(documents):
+         print(f"  Similarity: {scores[i, j]:.4f} | Document {j}: {doc[:80]}...")
+
+ # Query: What is machine learning?
+ #   Similarity: 0.9063 | Document 0: Machine learning is a subset of ...
+ #   Similarity: 0.7287 | Document 1: Neural networks are trained ...
+ #
+ # Query: How does neural network training work?
+ #   Similarity: 0.6725 | Document 0: Machine learning is a subset of ...
+ #   Similarity: 0.8287 | Document 1: Neural networks are trained ...
+ ```
+
+ ## Transformers Usage
+
+ See [here](https://huggingface.co/MongoDB/mdbr-leaf-mt/blob/main/transformers_example_mt.ipynb).
+
+ ## Asymmetric Retrieval Setup
+
+ `mdbr-leaf-mt` is *aligned* to [`mxbai-embed-large-v1`](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1), the model it was distilled from, which makes the asymmetric setup below possible:
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # queries and documents as defined in the Quickstart example above
+
+ # Use mdbr-leaf-mt for query encoding (real-time, low latency)
+ query_model = SentenceTransformer("MongoDB/mdbr-leaf-mt")
+ query_embeddings = query_model.encode(queries, prompt_name="query")
+
+ # Use a larger model for document encoding (one-time, at index time)
+ doc_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+ document_embeddings = doc_model.encode(documents)
+
+ # Compute similarities
+ scores = query_model.similarity(query_embeddings, document_embeddings)
+ ```
+ Retrieval results from the asymmetric mode are usually superior to those from the [standard mode above](#sentence-transformers).
+
+ ## MRL Truncation
+
+ Embeddings have been trained via [MRL](https://arxiv.org/abs/2205.13147) and can be truncated for more efficient storage:
+ ```python
+ from torch.nn import functional as F
+
+ query_embeds = model.encode(queries, prompt_name="query", convert_to_tensor=True)
+ doc_embeds = model.encode(documents, convert_to_tensor=True)
+
+ # Truncate and normalize according to MRL
+ query_embeds = F.normalize(query_embeds[:, :256], dim=-1)
+ doc_embeds = F.normalize(doc_embeds[:, :256], dim=-1)
+
+ similarities = model.similarity(query_embeds, doc_embeds)
+
+ print('After MRL:')
+ print(f"* Embeddings dimension: {query_embeds.shape[1]}")
+ print(f"* Similarities:\n\t{similarities}")
+
+ # After MRL:
+ # * Embeddings dimension: 256
+ # * Similarities:
+ # 	tensor([[0.9164, 0.7219],
+ #          [0.6682, 0.8393]], device='cuda:0')
+ ```
+
+ ## Vector Quantization
+ Vector quantization, for example to `int8` or `binary`, can be performed as follows:
+
+ **Note**: For vector quantization to types other than binary, we suggest performing a calibration to determine the optimal ranges, [see here](https://sbert.net/examples/sentence_transformer/applications/embedding-quantization/README.html#scalar-int8-quantization).
+ Good initial values are -1.0 and +1.0.
+ ```python
+ from sentence_transformers.quantization import quantize_embeddings
+ import torch
+
+ query_embeds = model.encode(queries, prompt_name="query")
+ doc_embeds = model.encode(documents)
+
+ # Quantize embeddings to int8 using -1.0 and +1.0
+ ranges = torch.tensor([[-1.0], [+1.0]]).expand(2, query_embeds.shape[1]).cpu().numpy()
+ query_embeds = quantize_embeddings(query_embeds, "int8", ranges=ranges)
+ doc_embeds = quantize_embeddings(doc_embeds, "int8", ranges=ranges)
+
+ # Calculate similarities; cast to int64 to avoid under/overflow
+ similarities = query_embeds.astype(int) @ doc_embeds.astype(int).T
+
+ print('After quantization:')
+ print(f"* Embeddings type: {query_embeds.dtype}")
+ print(f"* Similarities:\n{similarities}")
+
+ # After quantization:
+ # * Embeddings type: int8
+ # * Similarities:
+ # [[2202032 1422868]
+ #  [1421197 1845580]]
+ ```
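For `binary` quantization (mentioned above but not shown), here is a minimal sketch, assuming the `ubinary` precision of `quantize_embeddings` and ranking by Hamming distance; the actual scoring backend is up to your vector store:

```python
import numpy as np
from sentence_transformers.quantization import quantize_embeddings

# Pack the sign of each of the 1024 dimensions into bits: 1024 dims -> 128 bytes per vector.
query_bin = quantize_embeddings(model.encode(queries, prompt_name="query"), "ubinary")
doc_bin = quantize_embeddings(model.encode(documents), "ubinary")

# Rank by Hamming distance (lower = more similar): count the differing bits.
xor = np.bitwise_xor(query_bin[:, None, :], doc_bin[None, :, :])
hamming_distances = np.unpackbits(xor, axis=-1).sum(axis=-1)
print(hamming_distances)  # shape: (num_queries, num_documents)
```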
+
+ # Evaluation
+
+ The checkpoint used to produce the scores presented in the paper [is here](https://huggingface.co/MongoDB/mdbr-leaf-mt/commit/ea98995e96beac21b820aa8ad9afaa6fd29b243d).
+
+ # Citation
+
+ If you use this model in your work, please cite:
+
+ ```bibtex
+ @article{mdb_leaf,
+   title = {LEAF: Lightweight Embedding Alignment Knowledge Distillation Framework},
+   author = {Robin Vujanic and Thomas Rueckstiess},
+   year = {2025},
+   eprint = {TBD},
+   archiveprefix = {arXiv},
+   primaryclass = {FILL HERE},
+   url = {FILL HERE}
+ }
+ ```
+
+ # License
+
+ This model is released under the Apache 2.0 license.
+
+ # Contact
+
+ For questions or issues, please open an issue or pull request. You can also contact the MongoDB ML Research team at [email protected].
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 384,
+   "initializer_range": 0.02,
+   "intermediate_size": 1536,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 6,
+   "pad_token_id": 0,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.4",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 30522
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "model_type": "SentenceTransformer",
+   "__version__": {
+     "sentence_transformers": "5.1.0",
+     "transformers": "4.52.4",
+     "pytorch": "2.6.0+cu126"
+   },
+   "prompts": {
+     "query": "Represent this sentence for searching relevant passages: ",
+     "document": ""
+   },
+   "default_prompt_name": null,
+   "similarity_fn_name": "cosine"
+ }
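The `prompts` entry above is what `prompt_name="query"` resolves to at encode time. A small sketch (reusing `model` and `queries` from the Quickstart example) of the equivalence:

```python
# Equivalent ways to apply the query prompt defined in config_sentence_transformers.json.
QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

emb_a = model.encode(queries, prompt_name="query")
emb_b = model.encode([QUERY_PREFIX + q for q in queries])
# emb_a and emb_b should match up to floating-point noise.
```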
logo.webp ADDED
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f80e7c12db001fca76f7b85132fa616cb8b8f9bb11c83ecac8c1779420b65d0
+ size 90272656
modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Dense",
+     "type": "sentence_transformers.models.Dense"
+   }
+ ]
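`modules.json` describes a three-stage pipeline: Transformer encoder, mean Pooling, and Dense projection. A rough sketch of an equivalent composition (illustrative only; `SentenceTransformer("MongoDB/mdbr-leaf-mt")` assembles all three modules and loads their trained weights automatically):

```python
from torch import nn
from sentence_transformers import SentenceTransformer, models

# Illustrative composition of the three modules listed in modules.json.
transformer = models.Transformer("MongoDB/mdbr-leaf-mt", max_seq_length=512)
pooling = models.Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
dense = models.Dense(in_features=384, out_features=1024, bias=True,
                     activation_function=nn.Identity())  # trained weights would still need loading

model = SentenceTransformer(modules=[transformer, pooling, dense])
```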
sentence_bert_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "max_seq_length": 512,
+   "do_lower_case": false,
+   "model_args": {
+     "add_pooling_layer": false
+   }
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": false,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "max_length": 128,
+   "model_max_length": 512,
+   "never_split": null,
+   "pad_to_multiple_of": null,
+   "pad_token": "[PAD]",
+   "pad_token_type_id": 0,
+   "padding_side": "right",
+   "sep_token": "[SEP]",
+   "stride": 0,
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "truncation_side": "right",
+   "truncation_strategy": "longest_first",
+   "unk_token": "[UNK]"
+ }
transformers_example_mt.ipynb ADDED
@@ -0,0 +1,167 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "2a12a2b3",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from safetensors import safe_open\n",
+     "import torch\n",
+     "from torch.nn import functional as F\n",
+     "from transformers import AutoModel, AutoTokenizer"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "148ce181",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# First clone the model locally\n",
+     "!git clone https://huggingface.co/MongoDB/mdbr-leaf-mt"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "ba9ec6c7",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "Some weights of BertModel were not initialized from the model checkpoint at mdbr-leaf-mt and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n",
+       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+      ]
+     }
+    ],
+    "source": [
+     "# Then load it\n",
+     "MODEL = \"mdbr-leaf-mt\"\n",
+     "\n",
+     "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
+     "model = AutoModel.from_pretrained(MODEL)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "ebaf1a76",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "tensors = {}\n",
+     "with safe_open(MODEL + \"/2_Dense/model.safetensors\", framework=\"pt\") as f:\n",
+     "    for k in f.keys():\n",
+     "        tensors[k] = f.get_tensor(k)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 21,
+    "id": "03ffcd9c",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "Similarities:\n",
+       "tensor([[0.9063, 0.7287],\n",
+       "        [0.6725, 0.8287]])\n"
+      ]
+     }
+    ],
+    "source": [
+     "if 'linear.bias' in tensors:\n",
+     "    W_out = torch.nn.Linear(in_features=384, out_features=1024, bias=True)\n",
+     "    W_out.load_state_dict({\n",
+     "        \"weight\": tensors[\"linear.weight\"],\n",
+     "        \"bias\": tensors[\"linear.bias\"]\n",
+     "    })\n",
+     "else:\n",
+     "    W_out = torch.nn.Linear(in_features=384, out_features=1024, bias=False)\n",
+     "    W_out.load_state_dict({\n",
+     "        \"weight\": tensors[\"linear.weight\"]\n",
+     "    })\n",
+     "\n",
+     "_ = model.eval()\n",
+     "_ = W_out.eval()\n",
+     "\n",
+     "# Example queries and documents\n",
+     "queries = [\n",
+     "    \"What is machine learning?\",\n",
+     "    \"How does neural network training work?\"\n",
+     "]\n",
+     "\n",
+     "documents = [\n",
+     "    \"Machine learning is a subset of artificial intelligence that focuses on algorithms that can learn from data.\",\n",
+     "    \"Neural networks are trained through backpropagation, adjusting weights to minimize prediction errors.\"\n",
+     "]\n",
+     "\n",
+     "# Tokenize\n",
+     "QUERY_PREFIX = 'Represent this sentence for searching relevant passages: '\n",
+     "queries_with_prefix = [QUERY_PREFIX + query for query in queries]\n",
+     "\n",
+     "query_tokens = tokenizer(queries_with_prefix, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
+     "document_tokens = tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=512)\n",
+     "\n",
+     "# Perform inference\n",
+     "with torch.inference_mode():\n",
+     "    y_queries = model(**query_tokens).last_hidden_state\n",
+     "    y_docs = model(**document_tokens).last_hidden_state\n",
+     "\n",
+     "    # perform pooling\n",
+     "    y_queries = y_queries * query_tokens.attention_mask.unsqueeze(-1)\n",
+     "    y_queries_pooled = y_queries.sum(dim=1) / query_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
+     "\n",
+     "    y_docs = y_docs * document_tokens.attention_mask.unsqueeze(-1)\n",
+     "    y_docs_pooled = y_docs.sum(dim=1) / document_tokens.attention_mask.sum(dim=1, keepdim=True)\n",
+     "\n",
+     "    # map to desired output dimension\n",
+     "    query_embeddings = W_out(y_queries_pooled)\n",
+     "    document_embeddings = W_out(y_docs_pooled)\n",
+     "\n",
+     "similarities = F.cosine_similarity(query_embeddings.unsqueeze(0), document_embeddings.unsqueeze(1), dim=-1).T\n",
+     "print(f\"Similarities:\\n{similarities}\")\n",
+     "\n",
+     "# Similarities:\n",
+     "# tensor([[0.6857, 0.4598],\n",
+     "#         [0.4238, 0.5723]])"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "5a2b0244",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "alexis",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.7"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff