JedailabNetcheck committed
Commit 3130237 · 1 Parent(s): 222534c
Files changed (4)
  1. .gitattributes +2 -0
  2. README.md +5 -3
  3. app.py +364 -0
  4. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data.nosync/spanish_news_2020_1M-sentences.txt filter=lfs diff=lfs merge=lfs -text
+data.nosync/english_news_2020_1M-sentences.txt filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,14 @@
 ---
 title: Mtem Pruner Spanish
-emoji: 🐨
+emoji: ✂️
 colorFrom: gray
-colorTo: purple
+colorTo: green
 sdk: streamlit
 sdk_version: 1.43.2
 app_file: app.py
-pinned: false
+pinned: true
+license: apache-2.0
+short_description: Multilingual Text Embedding Model Pruner
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,364 @@
+import os
+import re
+import csv
+import json
+import torch
+import shutil
+import tempfile
+import textwrap
+import numpy as np
+import pandas as pd
+import streamlit as st
+from collections import Counter
+from tokenizers import Tokenizer
+import plotly.graph_objects as go
+from huggingface_hub import whoami, HfApi, snapshot_download
+from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, pipeline
+
+
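+# Each supported language maps to the NLLB code used to translate the test sentence
+# and the Hugging Face language tag written to the pruned model's card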
+LANGUAGES = {
+    "french": {"emoji": "🇫🇷", "nllb_code": "fra_Latn", "hf_code": "fr"},
+    "english": {"emoji": "🇬🇧", "nllb_code": "eng_Latn", "hf_code": "en"},
+    "german": {"emoji": "🇩🇪", "nllb_code": "deu_Latn", "hf_code": "de"},
+    "italian": {"emoji": "🇮🇹", "nllb_code": "ita_Latn", "hf_code": "it"},
+    "spanish": {"emoji": "🇪🇸", "nllb_code": "spa_Latn", "hf_code": "es"},
+    "portuguese": {"emoji": "🇵🇹", "nllb_code": "por_Latn", "hf_code": "pt"},
+}
+
+MODELS = [
+    "intfloat/multilingual-e5-small",
+    "intfloat/multilingual-e5-base",
+    "intfloat/multilingual-e5-large",
+    "BAAI/bge-m3",
+    "Alibaba-NLP/gte-multilingual-base",
+    # "jinaai/jina-embeddings-v3",  # TODO: uses ParametrizedEmbedding
+]
+
+def estimate_pruned_vocabulary(tokenizer: PreTrainedTokenizerFast, language: str):
+    """
+    Estimate the most common tokens in the language. You should first download the 1M-sentences
+    dataset for the desired language. Source: https://wortschatz.uni-leipzig.de/en/download/English
+    """
+    sentences_file = f'data.nosync/{language}_news_2020_1M-sentences.txt'
+    if os.path.exists(sentences_file):
+        my_bar = st.progress(0)
+        df = pd.read_csv(sentences_file, sep='\t', header=None, quoting=csv.QUOTE_NONE, names=['id', 'text'])
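+        # Start from the special token ids so they are always kept in the pruned vocabulary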
+        counter = Counter(tokenizer.all_special_ids)
+        for i, text in enumerate(df.text):
+            counter.update(tokenizer.encode(text))
+            my_bar.progress(i / len(df), text=f"{i/len(df)*100:.0f}%")
+        filtered_token_ids = sorted(counter.keys())
+        filtered_tokens = tokenizer.convert_ids_to_tokens(filtered_token_ids)
+        return set(filtered_tokens)
+    else:
+        raise FileNotFoundError(f"Missing sentences file: {sentences_file}")
+
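+# Cached across Streamlit reruns. trust_remote_code is required by models that ship
+# custom modeling code (e.g. Alibaba-NLP/gte-multilingual-base)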
+@st.cache_resource
+def load_model_and_tokenizer(model_name: str):
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
+    return model, tokenizer
+
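+# Count parameters, optionally restricted to parameters whose qualified name
+# starts with `layer_name` (e.g. "embeddings" or "encoder")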
+def count_parameters(model, layer_name: str = None):
+    return sum(p.numel() for name, p in model.named_parameters() if layer_name is None or name.startswith(layer_name))
+
+@st.cache_resource
+def get_test_sentence(target_lang: str, source_lang: str = "eng_Latn"):
+    text = """
+    Alan Mathison Turing (23 June 1912 - 7 June 1954) was an English mathematician,
+    computer scientist, logician, cryptanalyst, philosopher and theoretical biologist.
+    """
+    if target_lang == "eng_Latn":
+        return text
+    model_name = "facebook/nllb-200-distilled-600M"
+    translator = pipeline(task="translation", tokenizer=model_name, model=model_name)
+    return translator(text, src_lang=source_lang, tgt_lang=target_lang)[0]['translation_text']
+
+def push_to_hub(hf_username: str, hf_token: str, model_dir: str, private: bool = False):
+    api = HfApi(endpoint="https://huggingface.co", token=hf_token)
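+    # Name the repo after the local output folder, e.g. "spanish-multilingual-e5-small"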
+    repo_id = f"{hf_username}/{model_dir.split('/')[-1]}"
+    api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
+    api.upload_folder(repo_id=repo_id, folder_path=model_dir, commit_message="Upload pruned model")
+
+def prune_model(model_name: str, language: str, hf_username: str, hf_token: str, keep_english: bool):
+    st.markdown(f"- Let's prune the [**{model_name}**](https://huggingface.co/{model_name}) model to keep its **{language.capitalize()}** tokens only.")
+
+    # Load the model and its tokenizer
+    model, tokenizer = load_model_and_tokenizer(model_name)
+
+    # Calculate parameters for the original model
+    all_params = count_parameters(model)
+    encoder_params = count_parameters(model, layer_name="encoder")
+    embedding_params = count_parameters(model, layer_name="embeddings")
+
+    st.markdown(
+        f"- The original model has **{all_params/1e6:.1f}M** parameters, of which **{embedding_params/all_params*100:.0f}%** "+
+        f"(i.e., {embedding_params/1e6:.1f}M params) come from the *embedding matrix* and its {tokenizer.vocab_size} token entries. "+
+        f"This means that the contextualization of text sequences is actually done by a *{model.config.num_hidden_layers}-layer Transformer encoder* "+
+        f"with only **{encoder_params/1e6:.1f}M** parameters."
+    )
+
+    with st.status(f"Computing the {language.capitalize()} vocabulary...", expanded=True) as status:
+        filtered_tokens = estimate_pruned_vocabulary(tokenizer, language)
+        num_filtered_tokens = len(filtered_tokens)
+        st.write(
+            f"{language.capitalize()} only uses **{num_filtered_tokens/tokenizer.vocab_size*100:.0f}%** "+
+            f"of the model vocabulary (i.e., {num_filtered_tokens} out of the original {tokenizer.vocab_size} tokens)."
+        )
+        status.update(state="complete", expanded=True)
+
+    if keep_english:
+        with st.status("Computing the English vocabulary...", expanded=True) as status:
+            english_tokens = estimate_pruned_vocabulary(tokenizer, "english")
+            filtered_tokens.update(english_tokens)
+            st.write(f"Keeping the **English** tokens adds **{len(filtered_tokens) - num_filtered_tokens}** tokens to the vocabulary.")
+            num_filtered_tokens = len(filtered_tokens)
+            status.update(state="complete", expanded=True)
+
+    with st.status("Pruning the model...", expanded=True) as status:
+        st.write("- *Updating the tokenizer*")
+        outdir = f"{language}-{model_name.split('/')[-1]}"
+
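+        # Note: the vocabulary surgery below assumes a Unigram-style vocab ([[token, score], ...]),
+        # as used by the XLM-R sentencepiece tokenizers of the supported models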
+        # Export the tokenizer to a JSON string and access its vocabulary (list of lists: [[token, score], ...])
+        tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
+        original_vocab = tokenizer_json['model']['vocab']
+        original_token_to_id = {entry[0]: idx for idx, entry in enumerate(original_vocab)}
+
+        # Filter out the tokens to remove and reassign new IDs
+        new_id = 0
+        new_token_to_id = {}
+        new_id_to_original_id = {}
+        filtered_vocab_entries = []
+
+        for token, score in original_vocab:
+            if token in filtered_tokens:
+                filtered_vocab_entries.append([token, score])
+                new_token_to_id[token] = new_id
+                new_id_to_original_id[new_id] = original_token_to_id[token]
+                new_id += 1
+
+        # Update the vocab in the tokenizer JSON and rebuild the tokenizer from the modified JSON
+        tokenizer_json['model']['vocab'] = filtered_vocab_entries
+        new_backend_tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
+
+        # Create a new tokenizer instance and save it
+        new_tokenizer = PreTrainedTokenizerFast(tokenizer_object=new_backend_tokenizer, **tokenizer.init_kwargs)
+        new_tokenizer.save_pretrained(outdir)
+
+        st.write("- *Updating the embedding matrix*")
+        new_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+
+        # Create a new embedding matrix and map the original vectors to their new IDs
+        original_embeddings = new_model.get_input_embeddings().weight.data
+        new_embeddings = torch.nn.Embedding(
+            num_embeddings=new_tokenizer.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            padding_idx=new_tokenizer.pad_token_id,
+        )
+
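+        # Copy each kept token's original embedding row into its new position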
+        for new_id in range(new_tokenizer.vocab_size):
+            original_id = new_id_to_original_id.get(new_id)
+            new_embeddings.weight.data[new_id] = original_embeddings[original_id]
+
+        new_model.set_input_embeddings(new_embeddings)
+        new_model.config.vocab_size = new_tokenizer.vocab_size
+        new_model.save_pretrained(outdir)
+
+        status.update(state="complete", expanded=True)
+
+    with st.status("Testing the conversion...", expanded=True) as status:
+        st.write("- *Checking the pruned tokenizer*")
+        assert len(new_tokenizer) == num_filtered_tokens, f"ERROR: new tokenizer size ({len(new_tokenizer)}) != number of filtered tokens ({num_filtered_tokens})"
+        assert filtered_tokens == set(new_tokenizer.convert_ids_to_tokens(range(len(new_tokenizer)))), "ERROR: the new tokenizer vocabulary doesn't match the set of filtered tokens"
+
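+        # The original and pruned models should produce the same [cls] embedding
+        # for a test sentence translated into the target language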
+        st.write("- *Checking the pruned model*")
+        test_sentence = get_test_sentence(LANGUAGES[language]['nllb_code'])
+        with torch.inference_mode():
+            emb1 = model(**tokenizer(test_sentence, return_tensors='pt')).last_hidden_state[:, 0][0].numpy()
+            emb2 = new_model(**new_tokenizer(test_sentence, return_tensors='pt')).last_hidden_state[:, 0][0].numpy()
+        diff = np.abs(emb1 - emb2).max()
+        assert diff < 1e-6, f"ERROR: some dimensions of the two vectors show a non-negligible difference ({diff})"
+
+        st.write(f"""All good! The output *[cls]* token embeddings of the test sentence *"{test_sentence}"* should be nearly identical:""")
+        col1, col2 = st.columns(2)
+        with col1:
+            st.markdown("Original model:")
+            st.code(f"{emb1.tolist()}")
+        with col2:
+            st.markdown("Pruned model:")
+            st.code(f"{emb2.tolist()}")
+
+        status.update(state="complete", expanded=True)
+
+    # Visualize the result of the pruning process
+    pruned_all_params = count_parameters(new_model)
+    pruned_encoder_params = count_parameters(new_model, layer_name="encoder")
+    pruned_embedding_params = count_parameters(new_model, layer_name="embeddings")
+    st.markdown(f"The pruned model is **{pruned_all_params/all_params*100:.1f}%** of the original model size.")
+    data = {
+        'Model': ['Original', 'Pruned'],
+        'Embedding': [embedding_params / 1e6, pruned_embedding_params / 1e6],
+        'Encoder': [encoder_params / 1e6, pruned_encoder_params / 1e6]
+    }
+    fig = go.Figure(data=[
+        go.Bar(name='Embedding matrix', x=data['Model'], y=data['Embedding'], text=data['Embedding'], textposition='inside', marker_color='#E5B4B4'),
+        go.Bar(name='Transformer encoder', x=data['Model'], y=data['Encoder'], text=data['Encoder'], textposition='inside', marker_color='#7FBFE0')
+    ])
+    fig.update_layout(barmode='stack', yaxis_title='# Params (M)', height=400, margin=dict(t=10, b=10))
+    fig.update_traces(texttemplate='%{text:.1f}M', textposition='inside', insidetextanchor='middle')
+    st.plotly_chart(fig)
+
+    with st.status("Pushing the pruned model to your Hugging Face account...", expanded=True) as status:
+        st.write("- *Adding sentence-transformers files*")
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            snapshot_download(repo_id=model_name, local_dir=tmpdirname, token=hf_token)
+
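+            # Copy the sentence-transformers configuration files (if any) so the pruned
+            # model can still be loaded with the sentence-transformers library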
+            src_modules_json = os.path.join(tmpdirname, "modules.json")
+            if os.path.exists(src_modules_json):
+                shutil.copy2(src_modules_json, os.path.join(outdir, "modules.json"))
+
+            src_sentence_bert_config = os.path.join(tmpdirname, "sentence_bert_config.json")
+            if os.path.exists(src_sentence_bert_config):
+                shutil.copy2(src_sentence_bert_config, os.path.join(outdir, "sentence_bert_config.json"))
+
+            src_pooling_folder = os.path.join(tmpdirname, "1_Pooling")
+            if os.path.exists(src_pooling_folder):
+                shutil.copytree(src_pooling_folder, os.path.join(outdir, "1_Pooling"), dirs_exist_ok=True)
+
+            # Reuse the original model's license, falling back to "unknown" if none is found
+            original_license = "unknown"
+            src_readme = os.path.join(tmpdirname, "README.md")
+            if os.path.exists(src_readme):
+                with open(src_readme, 'r', encoding='utf-8') as file:
+                    content = file.read()
+                match = re.search(r'license:\s*(\S+)', content, re.IGNORECASE)
+                if match:
+                    original_license = match.group(1)
+
+        st.write("- *Adding a README*")
+        new_model_name = f"{hf_username}/{outdir.split('/')[-1]}"
+        readme_content = textwrap.dedent(f"""
+        ---
+        pipeline_tag: sentence-similarity
+        language: {LANGUAGES[language]['hf_code']}
+        license: {original_license}
+        tags:
+        - passage-retrieval
+        - sentence-similarity
+        - pruned
+        library_name: sentence-transformers
+        base_model: {model_name}
+        base_model_relation: quantized
+        ---
+        # {LANGUAGES[language]['emoji']} {new_model_name.split('/')[-1]}
+
+        This model is a {100 - pruned_all_params/all_params*100:.1f}% smaller version of [{model_name}](https://huggingface.co/{model_name})
+        for the {language.capitalize()} language, created using the [mtem-pruner](https://huggingface.co/spaces/antoinelouis/mtem-pruner) space.
+
+        This pruned model should perform similarly to the original model for {language.capitalize()} language tasks with a much smaller
+        memory footprint. However, it may not perform well for other languages present in the original multilingual model as tokens not
+        commonly used in {language.capitalize()} were removed from the original multilingual model's vocabulary.
+
+        ## Usage
+
+        You can use this model with the Transformers library:
+
+        ```python
+        from transformers import AutoModel, AutoTokenizer
+
+        model_name = "{new_model_name}"
+        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
+        ```
+
+        Or with the sentence-transformers library:
+
+        ```python
+        from sentence_transformers import SentenceTransformer
+
+        model = SentenceTransformer("{new_model_name}")
+        ```
+
+        **Credits**: cc [@antoinelouis](https://huggingface.co/antoinelouis)
+        """)
+        with open(os.path.join(outdir, "README.md"), "w") as f:
+            f.write(readme_content)
+
+        st.write("- *Pushing to Hub*")
+        push_to_hub(hf_username, hf_token, outdir)
+
+        shutil.rmtree(outdir)
+        status.update(state="complete", expanded=False)
+
+    st.markdown("Done! You can now load your pruned model like this:")
+    st.code(f"""
+    from transformers import AutoModel, AutoTokenizer
+
+    model_name = "{new_model_name}"
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=True)
+    """, language="python")
+
+
+def main():
+    st.header("Multilingual Text Embedding Model Pruner")
+    st.markdown("""
+    This space helps you create a smaller, language-specific version of a multilingual text embedding model. Here's what it does:
+
+    1. 🌎 Takes a state-of-the-art text embedding model that was trained on many languages
+    2. ✂️ Trims it down to focus on just one language by removing unused tokens from its vocabulary
+    3. 🚀 Gives you a smaller model that works just as well for your chosen language
+
+    #### Why is this useful?
+
+    - 💾 Get the same performance in your language with a much smaller model size
+    - 🌐 Great for low-resource environments with limited RAM
+
+    Ready to shrink your model? Let's get started!
+    """)
+
+    model_name = st.selectbox("Choose a multilingual model", MODELS)
+
+    col1, col2 = st.columns([3, 1])
+    with col1:
+        language = st.selectbox(
+            "Pick your target language",
+            options=list(LANGUAGES.keys()),
+            format_func=lambda x: f"{LANGUAGES[x]['emoji']} {x.capitalize()}"
+        )
+    with col2:
+        st.write("")
+        st.write("")
+        keep_english = st.checkbox("Keep English", value=False, help="Keep English tokens in addition to the selected language")
+
+    col3, col4 = st.columns(2)
+    with col3:
+        hf_username = st.text_input("Your Hugging Face username", placeholder="antoinelouis")
+    with col4:
+        hf_token = st.text_input("Your Hugging Face access token", type="password", placeholder="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
+
+    if st.button("Prune model"):
+        if not hf_username or not hf_token:
+            st.error("Your HF username and access token are required to save the pruned model on your account.")
+        else:
+            _ = whoami(token=hf_token)  # validates the access token before starting
+            prune_model(model_name, language, hf_username, hf_token, keep_english)
+
+    st.markdown(
+        """
+        <style>
+        .credits {
+            position: fixed;
+            right: 10px;
+            bottom: 10px;
+            color: #888888;
+            font-size: 11px;
+        }
+        </style>
+        <div class="credits">
+            Credits to <a href="https://gist.github.com/avidale/44cd35bfcdaf8bedf51d97c468cc8001" target="_blank">@avidale</a> for inspiration.
+        </div>
+        """,
+        unsafe_allow_html=True
+    )
+
+if __name__ == "__main__":
+    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+torch
+transformers
+pandas
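+# numpy is pinned below 2.0, presumably because wheels built against NumPy 1.x break under NumPy 2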
+numpy<2
+einops
+streamlit
+plotly