tuandunghcmut committed
Commit 42c85ac · verified
1 Parent(s): d2ab3a0

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. DeepSeek-VL2/.ipynb_checkpoints/README-checkpoint.md +397 -0
  2. DeepSeek-VL2/.ipynb_checkpoints/inference-checkpoint.py +185 -0
  3. DeepSeek-VL2/.ipynb_checkpoints/requirements-checkpoint.txt +19 -0
  4. DeepSeek-VL2/deepseek_vl2.egg-info/requires.txt +33 -0
  5. DeepSeek-VL2/deepseek_vl2.egg-info/top_level.txt +1 -0
  6. DeepSeek-VL2/deepseek_vl2/__init__.py +31 -0
  7. DeepSeek-VL2/deepseek_vl2/__pycache__/__init__.cpython-310.pyc +0 -0
  8. DeepSeek-VL2/deepseek_vl2/__pycache__/__init__.cpython-39.pyc +0 -0
  9. DeepSeek-VL2/deepseek_vl2/models/__pycache__/__init__.cpython-310.pyc +0 -0
  10. DeepSeek-VL2/deepseek_vl2/models/__pycache__/__init__.cpython-311.pyc +0 -0
  11. DeepSeek-VL2/deepseek_vl2/models/__pycache__/configuration_deepseek.cpython-39.pyc +0 -0
  12. DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-311.pyc +0 -0
  13. DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-312.pyc +0 -0
  14. DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-39.pyc +0 -0
  15. DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-311.pyc +0 -0
  16. DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-312.pyc +0 -0
  17. DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-39.pyc +0 -0
  18. DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-310.pyc +0 -0
  19. DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-311.pyc +0 -0
  20. DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-312.pyc +0 -0
  21. DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-39.pyc +0 -0
  22. DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-310.pyc +0 -0
  23. DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-312.pyc +0 -0
  24. DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-39.pyc +0 -0
  25. DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/gradio_utils.cpython-312.pyc +0 -0
  26. DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/utils.cpython-312.pyc +0 -0
  27. DeepSeek-VL2/deepseek_vl2/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  28. DeepSeek-VL2/deepseek_vl2/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  29. DeepSeek-VL2/deepseek_vl2/utils/__pycache__/io.cpython-312.pyc +0 -0
  30. DeepSeek-VL2/deepseek_vl2/utils/__pycache__/io.cpython-39.pyc +0 -0
  31. DeepSeek-VL2/images/logo.png +0 -0
  32. DeepSeek-VL2/images/logo.svg +22 -0
  33. DeepSeek-VL2/images/monday.jpg +0 -0
  34. DeepSeek-VL2/images/visual_grounding_2.jpg +0 -0
  35. DeepSeek-VL2/images/vl2_teaser.jpeg +0 -0
  36. VLM2Vec/archive/gather_score_byckpt_aws.py +132 -0
  37. VLM2Vec/archive/merge.py +26 -0
  38. VLM2Vec/archive/testset_stats.py +66 -0
  39. VLM2Vec/evaluation/eval_flickr.py +124 -0
  40. VLM2Vec/figures/example.jpg +0 -0
  41. VLM2Vec/grad_cache/__init__.py +4 -0
  42. VLM2Vec/grad_cache/cachex/__init__.py +3 -0
  43. VLM2Vec/grad_cache/context_managers.py +21 -0
  44. VLM2Vec/grad_cache/functional.py +91 -0
  45. VLM2Vec/grad_cache/grad_cache.py +279 -0
  46. VLM2Vec/grad_cache/loss.py +80 -0
  47. VLM2Vec/grad_cache/minigc_cmd.md +90 -0
  48. VLM2Vec/scripts/llava_next/demo.py +46 -0
  49. VLM2Vec/scripts/llava_next/run_eval_flickr_llava_next.sh +21 -0
  50. VLM2Vec/src/arguments.py +121 -0
DeepSeek-VL2/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,397 @@
1
+ <!-- markdownlint-disable first-line-h1 -->
2
+ <!-- markdownlint-disable html -->
3
+ <!-- markdownlint-disable no-duplicate-header -->
4
+
5
+ <div align="center">
6
+ <img src="images/logo.svg" width="60%" alt="DeepSeek LLM" />
7
+ </div>
8
+ <hr>
9
+ <div align="center">
10
+
11
+ <a href="https://www.deepseek.com/" target="_blank">
12
+ <img alt="Homepage" src="images/badge.svg" />
13
+ </a>
14
+ <a href="" target="_blank">
15
+ <img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-DeepSeek%20VL-536af5?color=536af5&logoColor=white" />
16
+ </a>
17
+ <a href="https://huggingface.co/deepseek-ai" target="_blank">
18
+ <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" />
19
+ </a>
20
+
21
+ </div>
22
+
23
+
24
+ <div align="center">
25
+
26
+ <a href="https://discord.gg/Tc7c45Zzu5" target="_blank">
27
+ <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" />
28
+ </a>
29
+ <a href="images/qr.jpeg" target="_blank">
30
+ <img alt="Wechat" src="https://img.shields.io/badge/WeChat-DeepSeek%20AI-brightgreen?logo=wechat&logoColor=white" />
31
+ </a>
32
+ <a href="https://twitter.com/deepseek_ai" target="_blank">
33
+ <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" />
34
+ </a>
35
+
36
+ </div>
37
+
38
+ <div align="center">
39
+
40
+ <a href="LICENSE-CODE">
41
+ <img alt="Code License" src="https://img.shields.io/badge/Code_License-MIT-f5de53?&color=f5de53">
42
+ </a>
43
+ <a href="LICENSE-MODEL">
44
+ <img alt="Model License" src="https://img.shields.io/badge/Model_License-Model_Agreement-f5de53?&color=f5de53">
45
+ </a>
46
+ </div>
47
+
48
+
49
+ <p align="center">
50
+ <a href="https://github.com/deepseek-ai/DeepSeek-VL2/tree/main?tab=readme-ov-file#3-model-download"><b>📥 Model Download</b></a> |
51
+ <a href="https://github.com/deepseek-ai/DeepSeek-VL2/tree/main?tab=readme-ov-file#4-quick-start"><b>⚡ Quick Start</b></a> |
52
+ <a href="https://github.com/deepseek-ai/DeepSeek-VL2/tree/main?tab=readme-ov-file#5-license"><b>📜 License</b></a> |
53
+ <a href="https://github.com/deepseek-ai/DeepSeek-VL2/tree/main?tab=readme-ov-file#6-citation"><b>📖 Citation</b></a> <br>
54
+ <a href="./DeepSeek_VL2_paper.pdf"><b>📄 Paper Link</b></a> |
55
+ <a href="https://arxiv.org/abs/2412.10302"><b>📄 Arxiv Paper Link</b></a> |
56
+ <a href=""><b>👁️ Demo</b></a>
57
+ </p>
58
+
59
+ ## 1. Introduction
60
+
61
+ Introducing DeepSeek-VL2, an advanced series of large Mixture-of-Experts (MoE) Vision-Language Models that significantly improves upon its predecessor, DeepSeek-VL. DeepSeek-VL2 demonstrates superior capabilities across various tasks, including but not limited to visual question answering, optical character recognition, document/table/chart understanding, and visual grounding. Our model series is composed of three variants: DeepSeek-VL2-Tiny, DeepSeek-VL2-Small and DeepSeek-VL2, with 1.0B, 2.8B and 4.5B activated parameters respectively.
62
+ DeepSeek-VL2 achieves competitive or state-of-the-art performance with similar or fewer activated parameters compared to existing open-source dense and MoE-based models.
63
+
64
+
65
+ [DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding]()
66
+
67
+ Zhiyu Wu*, Xiaokang Chen*, Zizheng Pan*, Xingchao Liu*, Wen Liu**, Damai Dai, Huazuo Gao, Yiyang Ma, Chengyue Wu, Bingxuan Wang, Zhenda Xie, Yu Wu, Kai Hu, Jiawei Wang, Yaofeng Sun, Yukun Li, Yishi Piao, Kang Guan, Aixin Liu, Xin Xie, Yuxiang You, Kai Dong, Xingkai Yu, Haowei Zhang, Liang Zhao, Yisong Wang, Chong Ruan*** (* Equal Contribution, ** Project Lead, *** Corresponding author)
68
+
69
+ ![](./images/vl2_teaser.jpeg)
70
+
71
+ ## 2. Release
72
+ ✅ <b>2024-12-25</b>: Gradio Demo Example, Incremental Prefilling and VLMEvalKit Support.
73
+
74
+ ✅ <b>2024-12-13</b>: DeepSeek-VL2 family released, including <code>DeepSeek-VL2-tiny</code>, <code>DeepSeek-VL2-small</code>, <code>DeepSeek-VL2</code>.
75
+
76
+ ## 3. Model Download
77
+
78
+ We release the DeepSeek-VL2 family, including <code>DeepSeek-VL2-tiny</code>, <code>DeepSeek-VL2-small</code>, and <code>DeepSeek-VL2</code>, to support a broader and more diverse range of research within both academic and commercial communities.
80
+ Please note that the use of this model is subject to the terms outlined in [License section](#5-license).
81
+
82
+ ### Huggingface
83
+
84
+ | Model | Sequence Length | Download |
85
+ |--------------|-----------------|-----------------------------------------------------------------------------|
86
+ | DeepSeek-VL2-tiny | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl2-tiny) |
87
+ | DeepSeek-VL2-small | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl2-small) |
88
+ | DeepSeek-VL2 | 4096 | [🤗 Hugging Face](https://huggingface.co/deepseek-ai/deepseek-vl2) |
89
+
90
+
91
+ ## 4. Quick Start
92
+
93
+ ### Installation
94
+
95
+ In a `Python >= 3.8` environment, install the necessary dependencies by running the following command:
96
+
97
+ ```shell
98
+ pip install -e .
99
+ ```
100
+
101
+ ### Simple Inference Example with One Image
102
+
103
+ **Note: You may need 80GB of GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
104
+
105
+ ```python
106
+ import torch
107
+ from transformers import AutoModelForCausalLM
108
+
109
+ from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
110
+ from deepseek_vl2.utils.io import load_pil_images
111
+
112
+
113
+ # specify the path to the model
114
+ model_path = "deepseek-ai/deepseek-vl2-tiny"
115
+ vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
116
+ tokenizer = vl_chat_processor.tokenizer
117
+
118
+ vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
119
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
120
+
121
+ ## single image conversation example
122
+ conversation = [
123
+ {
124
+ "role": "<|User|>",
125
+ "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
126
+ "images": ["./images/visual_grounding_1.jpeg"],
127
+ },
128
+ {"role": "<|Assistant|>", "content": ""},
129
+ ]
130
+
131
+ # load images and prepare for inputs
132
+ pil_images = load_pil_images(conversation)
133
+ prepare_inputs = vl_chat_processor(
134
+ conversations=conversation,
135
+ images=pil_images,
136
+ force_batchify=True,
137
+ system_prompt=""
138
+ ).to(vl_gpt.device)
139
+
140
+ # run image encoder to get the image embeddings
141
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
142
+
143
+ # run the model to get the response
144
+ outputs = vl_gpt.language.generate(
145
+ inputs_embeds=inputs_embeds,
146
+ attention_mask=prepare_inputs.attention_mask,
147
+ pad_token_id=tokenizer.eos_token_id,
148
+ bos_token_id=tokenizer.bos_token_id,
149
+ eos_token_id=tokenizer.eos_token_id,
150
+ max_new_tokens=512,
151
+ do_sample=False,
152
+ use_cache=True
153
+ )
154
+
155
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
156
+ print(f"{prepare_inputs['sft_format'][0]}", answer)
157
+ ```
158
+
159
+ And the output is something like:
160
+ ```
161
+ <|User|>: <image>
162
+ <|ref|>The giraffe at the back.<|/ref|>.
163
+
164
+ <|Assistant|>: <|ref|>The giraffe at the back.<|/ref|><|det|>[[580, 270, 999, 900]]<|/det|><|end▁of▁sentence|>
165
+ ```
166
+
167
+ ### Simple Inference Example with Multiple Images
168
+
169
+ **Note: You may need 80GB of GPU memory to run this script with deepseek-vl2-small, and even more for deepseek-vl2.**
170
+
171
+ ```python
172
+ import torch
173
+ from transformers import AutoModelForCausalLM
174
+
175
+ from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
176
+ from deepseek_vl2.utils.io import load_pil_images
177
+
178
+
179
+ # specify the path to the model
180
+ model_path = "deepseek-ai/deepseek-vl2-tiny"
181
+ vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
182
+ tokenizer = vl_chat_processor.tokenizer
183
+
184
+ vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
185
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
186
+
187
+ # multiple images/interleaved image-text
188
+ conversation = [
189
+ {
190
+ "role": "<|User|>",
191
+ "content": "This is image_1: <image>\n"
192
+ "This is image_2: <image>\n"
193
+ "This is image_3: <image>\n Can you tell me what are in the images?",
194
+ "images": [
195
+ "images/multi_image_1.jpeg",
196
+ "images/multi_image_2.jpeg",
197
+ "images/multi_image_3.jpeg",
198
+ ],
199
+ },
200
+ {"role": "<|Assistant|>", "content": ""}
201
+ ]
202
+
203
+ # load images and prepare for inputs
204
+ pil_images = load_pil_images(conversation)
205
+ prepare_inputs = vl_chat_processor(
206
+ conversations=conversation,
207
+ images=pil_images,
208
+ force_batchify=True,
209
+ system_prompt=""
210
+ ).to(vl_gpt.device)
211
+
212
+ # run image encoder to get the image embeddings
213
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
214
+
215
+ # run the model to get the response
216
+ outputs = vl_gpt.language.generate(
217
+ inputs_embeds=inputs_embeds,
218
+ attention_mask=prepare_inputs.attention_mask,
219
+ pad_token_id=tokenizer.eos_token_id,
220
+ bos_token_id=tokenizer.bos_token_id,
221
+ eos_token_id=tokenizer.eos_token_id,
222
+ max_new_tokens=512,
223
+ do_sample=False,
224
+ use_cache=True
225
+ )
226
+
227
+ answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=False)
228
+ print(f"{prepare_inputs['sft_format'][0]}", answer)
229
+ ```
230
+
231
+ And the output is something like:
232
+ ```
233
+ <|User|>: This is image_1: <image>
234
+ This is image_2: <image>
235
+ This is image_3: <image>
236
+ Can you tell me what are in the images?
237
+
238
+ <|Assistant|>: The images show three different types of vegetables. Image_1 features carrots, which are orange with green tops. Image_2 displays corn cobs, which are yellow with green husks. Image_3 contains raw pork ribs, which are pinkish-red with some marbling.<|end▁of▁sentence|>
239
+ ```
240
+
241
+ ### Simple Inference Example with Incremental Prefilling
242
+
243
+ **Note: We use incremental prefilling to run inference with deepseek-vl2-small within 40GB of GPU memory.**
244
+
245
+ ```python
246
+ import torch
247
+ from transformers import AutoModelForCausalLM
248
+
249
+ from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
250
+ from deepseek_vl2.utils.io import load_pil_images
251
+
252
+
253
+ # specify the path to the model
254
+ model_path = "deepseek-ai/deepseek-vl2-small"
255
+ vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
256
+ tokenizer = vl_chat_processor.tokenizer
257
+
258
+ vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
259
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
260
+
261
+ # multiple images/interleaved image-text
262
+ conversation = [
263
+ {
264
+ "role": "<|User|>",
265
+ "content": "This is image_1: <image>\n"
266
+ "This is image_2: <image>\n"
267
+ "This is image_3: <image>\n Can you tell me what are in the images?",
268
+ "images": [
269
+ "images/multi_image_1.jpeg",
270
+ "images/multi_image_2.jpeg",
271
+ "images/multi_image_3.jpeg",
272
+ ],
273
+ },
274
+ {"role": "<|Assistant|>", "content": ""}
275
+ ]
276
+
277
+ # load images and prepare for inputs
278
+ pil_images = load_pil_images(conversation)
279
+ prepare_inputs = vl_chat_processor(
280
+ conversations=conversation,
281
+ images=pil_images,
282
+ force_batchify=True,
283
+ system_prompt=""
284
+ ).to(vl_gpt.device)
285
+
286
+ with torch.no_grad():
287
+ # run image encoder to get the image embeddings
288
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
289
+
290
+ # incremental_prefilling when using 40G GPU for vl2-small
291
+ inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
292
+ input_ids=prepare_inputs.input_ids,
293
+ images=prepare_inputs.images,
294
+ images_seq_mask=prepare_inputs.images_seq_mask,
295
+ images_spatial_crop=prepare_inputs.images_spatial_crop,
296
+ attention_mask=prepare_inputs.attention_mask,
297
+ chunk_size=512 # prefilling size
298
+ )
299
+
300
+ # run the model to get the response
301
+ outputs = vl_gpt.generate(
302
+ inputs_embeds=inputs_embeds,
303
+ input_ids=prepare_inputs.input_ids,
304
+ images=prepare_inputs.images,
305
+ images_seq_mask=prepare_inputs.images_seq_mask,
306
+ images_spatial_crop=prepare_inputs.images_spatial_crop,
307
+ attention_mask=prepare_inputs.attention_mask,
308
+ past_key_values=past_key_values,
309
+
310
+ pad_token_id=tokenizer.eos_token_id,
311
+ bos_token_id=tokenizer.bos_token_id,
312
+ eos_token_id=tokenizer.eos_token_id,
313
+ max_new_tokens=512,
314
+
315
+ do_sample=False,
316
+ use_cache=True,
317
+ )
318
+
319
+ answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
320
+
321
+ print(f"{prepare_inputs['sft_format'][0]}", answer)
322
+ ```
323
+
324
+ And the output is something like:
325
+ ```
326
+ <|User|>: This is image_1: <image>
327
+ This is image_2: <image>
328
+ This is image_3: <image>
329
+ Can you tell me what are in the images?
330
+
331
+ <|Assistant|>: The first image contains carrots. The second image contains corn. The third image contains meat.<|end▁of▁sentence|>
332
+ ```
333
+
334
+ ### Full Inference Example
335
+ ```shell
336
+ # without incremental prefilling
337
+ CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2"
338
+
339
+ # with incremental prefilling, when using 40G GPU for vl2-small
340
+ CUDA_VISIBLE_DEVICES=0 python inference.py --model_path "deepseek-ai/deepseek-vl2-small" --chunk_size 512
341
+
342
+ ```
343
+
344
+
345
+ ### Gradio Demo
346
+
347
+ * Install the necessary dependencies:
348
+ ```shell
349
+ pip install -e .[gradio]
350
+ ```
351
+
352
+ * Then run the following command:
353
+
354
+ ```shell
355
+ # vl2-tiny, 3.37B-MoE in total, activated 1B, can be run on a single GPU < 40GB
356
+ CUDA_VISIBLE_DEVICES=2 python web_demo.py \
357
+ --model_name "deepseek-ai/deepseek-vl2-tiny" \
358
+ --port 37914
359
+
360
+
361
+ # vl2-small, 16.1B-MoE in total, activated 2.4B
362
+ # If running on an A100 40GB GPU, set `--chunk_size 512` to enable incremental prefilling and save memory; responses may be slower.
363
+ # If running on a GPU with more than 40GB of memory, you can omit `--chunk_size 512` for faster responses.
364
+ CUDA_VISIBLE_DEVICES=2 python web_demo.py \
365
+ --model_name "deepseek-ai/deepseek-vl2-small" \
366
+ --port 37914 \
367
+ --chunk_size 512
368
+
369
+ # vl2, 27.5B-MoE in total, activated 4.2B
370
+ CUDA_VISIBLE_DEVICES=2 python web_demo.py \
371
+ --model_name "deepseek-ai/deepseek-vl2" \
372
+ --port 37914
373
+ ```
374
+
375
+ * **Important**: This is a basic and native demo implementation without any deployment optimizations, which may result in slower performance. For production environments, consider using optimized deployment solutions, such as vllm, sglang, lmdeploy, etc. These optimizations will help achieve faster response times and better cost efficiency.
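
  For illustration only, here is a minimal sketch of what a client call could look like once the model is deployed behind an OpenAI-compatible endpoint (which vLLM, SGLang, and LMDeploy all provide), assuming the chosen backend supports this model; the base URL, port, and image URL below are placeholders:

  ```python
  # Hypothetical client for an OpenAI-compatible server fronting DeepSeek-VL2.
  # The endpoint, API key, and image URL are placeholders, not part of this repo.
  from openai import OpenAI

  client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

  response = client.chat.completions.create(
      model="deepseek-ai/deepseek-vl2-tiny",
      messages=[
          {
              "role": "user",
              "content": [
                  {"type": "image_url", "image_url": {"url": "https://example.com/sample.jpg"}},
                  {"type": "text", "text": "Describe this image."},
              ],
          }
      ],
      max_tokens=256,
      temperature=0.0,
  )
  print(response.choices[0].message.content)
  ```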
376
+
377
+ ## 5. License
378
+
379
+ This code repository is licensed under [MIT License](./LICENSE-CODE). The use of DeepSeek-VL2 models is subject to [DeepSeek Model License](./LICENSE-MODEL). DeepSeek-VL2 series supports commercial use.
380
+
381
+ ## 6. Citation
382
+
383
+ ```
384
+ @misc{wu2024deepseekvl2mixtureofexpertsvisionlanguagemodels,
385
+ title={DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding},
386
+ author={Zhiyu Wu and Xiaokang Chen and Zizheng Pan and Xingchao Liu and Wen Liu and Damai Dai and Huazuo Gao and Yiyang Ma and Chengyue Wu and Bingxuan Wang and Zhenda Xie and Yu Wu and Kai Hu and Jiawei Wang and Yaofeng Sun and Yukun Li and Yishi Piao and Kang Guan and Aixin Liu and Xin Xie and Yuxiang You and Kai Dong and Xingkai Yu and Haowei Zhang and Liang Zhao and Yisong Wang and Chong Ruan},
387
+ year={2024},
388
+ eprint={2412.10302},
389
+ archivePrefix={arXiv},
390
+ primaryClass={cs.CV},
391
+ url={https://arxiv.org/abs/2412.10302},
392
+ }
393
+ ```
394
+
395
+ ## 7. Contact
396
+
397
+ If you have any questions, please raise an issue or contact us at [[email protected]](mailto:[email protected]).
DeepSeek-VL2/.ipynb_checkpoints/inference-checkpoint.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+ from argparse import ArgumentParser
21
+ from typing import List, Dict
22
+ import torch
23
+ from transformers import AutoModelForCausalLM
24
+ import PIL.Image
25
+
26
+ from deepseek_vl2.models import DeepseekVLV2ForCausalLM, DeepseekVLV2Processor
27
+ from deepseek_vl2.serve.app_modules.utils import parse_ref_bbox
28
+
29
+
30
+ def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]:
31
+ """
32
+
33
+ Args:
34
+ conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
35
+ [
36
+ {
37
+ "role": "User",
38
+ "content": "<image>\nExtract all information from this image and convert them into markdown format.",
39
+ "images": ["./examples/table_datasets.png"]
40
+ },
41
+ {"role": "Assistant", "content": ""},
42
+ ]
43
+
44
+ Returns:
45
+ pil_images (List[PIL.Image.Image]): the list of PIL images.
46
+
47
+ """
48
+
49
+ pil_images = []
50
+
51
+ for message in conversations:
52
+ if "images" not in message:
53
+ continue
54
+
55
+ for image_path in message["images"]:
56
+ pil_img = PIL.Image.open(image_path)
57
+ pil_img = pil_img.convert("RGB")
58
+ pil_images.append(pil_img)
59
+
60
+ return pil_images
61
+
62
+
63
+ def main(args):
64
+
65
+ dtype = torch.bfloat16
66
+
67
+ # specify the path to the model
68
+ model_path = args.model_path
69
+ vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(model_path)
70
+ tokenizer = vl_chat_processor.tokenizer
71
+
72
+ vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
73
+ model_path,
74
+ trust_remote_code=True,
75
+ torch_dtype=dtype
76
+ )
77
+ vl_gpt = vl_gpt.cuda().eval()
78
+
79
+ # single image conversation example
80
+ conversation = [
81
+ {
82
+ "role": "<|User|>",
83
+ "content": "<image>\n<image>\n<|grounding|>In the first image, an object within the red rectangle is marked. Locate the object of the same category in the second image.",
84
+ "images": [
85
+ "images/incontext_visual_grounding_1.jpeg",
86
+ "images/icl_vg_2.jpeg"
87
+ ],
88
+ },
89
+ {"role": "<|Assistant|>", "content": ""},
90
+ ]
91
+
92
+ # conversation = [
93
+ # {
94
+ # "role": "<|User|>",
95
+ # "content": "<image>\n<|ref|>The giraffe at the back.<|/ref|>.",
96
+ # "images": ["./images/visual_grounding_1.jpeg"],
97
+ # },
98
+ # {"role": "<|Assistant|>", "content": ""},
99
+ # ]
100
+
101
+ # load images and prepare for inputs
102
+ pil_images = load_pil_images(conversation)
103
+ print(f"len(pil_images) = {len(pil_images)}")
104
+
105
+ # input_ids = batched_input_ids,
106
+ # attention_mask = batched_attention_mask,
107
+ # labels = batched_labels,
108
+ # images_tiles = batched_images,
109
+ # images_seq_mask = batched_images_seq_mask,
110
+ # images_spatial_crop = batched_images_spatial_crop,
111
+ # sft_format = batched_sft_format,
112
+ # seq_lens = seq_lens
113
+
114
+ prepare_inputs = vl_chat_processor.__call__(
115
+ conversations=conversation,
116
+ images=pil_images,
117
+ force_batchify=True,
118
+ system_prompt=""
119
+ ).to(vl_gpt.device, dtype=dtype)
120
+
121
+ # for key in prepare_inputs.keys():
122
+ # value = prepare_inputs[key]
123
+ # if isinstance(value, list):
124
+ # print(key, len(value), type(value))
125
+ # elif isinstance(value, torch.Tensor):
126
+ # print(key, value.shape, type(value))
127
+
128
+ with torch.no_grad():
129
+ # run image encoder to get the image embeddings
130
+ # inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
131
+
132
+ # incremental_prefilling when using 40G GPU for vl2-small
133
+ inputs_embeds, past_key_values = vl_gpt.incremental_prefilling(
134
+ input_ids=prepare_inputs.input_ids,
135
+ images=prepare_inputs.images,
136
+ images_seq_mask=prepare_inputs.images_seq_mask,
137
+ images_spatial_crop=prepare_inputs.images_spatial_crop,
138
+ attention_mask=prepare_inputs.attention_mask,
139
+ chunk_size=args.chunk_size
140
+ )
141
+
142
+ # run the model to get the response
143
+ outputs = vl_gpt.generate(
144
+ # inputs_embeds=inputs_embeds[:, -1:],
145
+ # input_ids=prepare_inputs.input_ids[:, -1:],
146
+ inputs_embeds=inputs_embeds,
147
+ input_ids=prepare_inputs.input_ids,
148
+ images=prepare_inputs.images,
149
+ images_seq_mask=prepare_inputs.images_seq_mask,
150
+ images_spatial_crop=prepare_inputs.images_spatial_crop,
151
+ attention_mask=prepare_inputs.attention_mask,
152
+ past_key_values=past_key_values,
153
+
154
+ pad_token_id=tokenizer.eos_token_id,
155
+ bos_token_id=tokenizer.bos_token_id,
156
+ eos_token_id=tokenizer.eos_token_id,
157
+ max_new_tokens=512,
158
+
159
+ # do_sample=False,
160
+ # repetition_penalty=1.1,
161
+
162
+ do_sample=True,
163
+ temperature=0.4,
164
+ top_p=0.9,
165
+ repetition_penalty=1.1,
166
+
167
+ use_cache=True,
168
+ )
169
+
170
+ answer = tokenizer.decode(outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), skip_special_tokens=False)
171
+ print(f"{prepare_inputs['sft_format'][0]}", answer)
172
+
173
+ vg_image = parse_ref_bbox(answer, image=pil_images[-1])
174
+ if vg_image is not None:
175
+ vg_image.save("./vg.jpg", format="JPEG", quality=85)
176
+
177
+
178
+ if __name__ == "__main__":
179
+ parser = ArgumentParser()
180
+ parser.add_argument("--model_path", type=str, required=True,
181
+ default="deepseek-ai/deepseek-vl2",
182
+ help="model name or local path to the model")
183
+ parser.add_argument("--chunk_size", type=int, default=512, help="chunk size for incremental prefilling")
184
+ args = parser.parse_args()
185
+ main(args)
DeepSeek-VL2/.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,19 @@
1
+ torch==2.0.1
2
+ transformers==4.38.2
3
+ timm>=0.9.16
4
+ accelerate
5
+ sentencepiece
6
+ attrdict
7
+ einops
8
+
9
+ # for gradio demo
10
+ gradio==3.48.0
11
+ gradio-client==0.6.1
12
+ mdtex2html==1.3.0
13
+ pypinyin==0.50.0
14
+ tiktoken==0.5.2
15
+ tqdm==4.64.0
16
+ colorama==0.4.5
17
+ Pygments==2.12.0
18
+ markdown==3.4.1
19
+ SentencePiece==0.1.96
DeepSeek-VL2/deepseek_vl2.egg-info/requires.txt ADDED
@@ -0,0 +1,33 @@
1
+ torch>=2.0.1
2
+ transformers>=4.38.2
3
+ timm>=0.9.16
4
+ accelerate
5
+ sentencepiece
6
+ attrdict
7
+ einops
8
+
9
+ [gradio]
10
+ gradio==3.48.0
11
+ gradio-client==0.6.1
12
+ mdtex2html==1.3.0
13
+ pypinyin==0.50.0
14
+ tiktoken==0.5.2
15
+ tqdm==4.64.0
16
+ colorama==0.4.5
17
+ Pygments==2.12.0
18
+ markdown==3.4.1
19
+ SentencePiece==0.1.96
20
+
21
+ [lint]
22
+ isort
23
+ black[jupyter]>=22.6.0
24
+ pylint[spelling]>=2.15.0
25
+ flake8
26
+ flake8-bugbear
27
+ flake8-comprehensions
28
+ flake8-docstrings
29
+ flake8-pyi
30
+ flake8-simplify
31
+ ruff
32
+ pyenchant
33
+ pre-commit
DeepSeek-VL2/deepseek_vl2.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ deepseek_vl2
DeepSeek-VL2/deepseek_vl2/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2023-2024 DeepSeek.
2
+ #
3
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ # this software and associated documentation files (the "Software"), to deal in
5
+ # the Software without restriction, including without limitation the rights to
6
+ # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7
+ # the Software, and to permit persons to whom the Software is furnished to do so,
8
+ # subject to the following conditions:
9
+ #
10
+ # The above copyright notice and this permission notice shall be included in all
11
+ # copies or substantial portions of the Software.
12
+ #
13
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19
+
20
+
21
+ # check if the python version is 3.10 or above
22
+ import sys
23
+
24
+ if sys.version_info >= (3, 10):
25
+ print("Python version is above 3.10, patching the collections module.")
26
+ # Monkey patch collections
27
+ import collections
28
+ import collections.abc
29
+
30
+ for type_name in collections.abc.__all__:
31
+ setattr(collections, type_name, getattr(collections.abc, type_name))
DeepSeek-VL2/deepseek_vl2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (439 Bytes).
DeepSeek-VL2/deepseek_vl2/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (429 Bytes).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (348 Bytes).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (403 Bytes).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/configuration_deepseek.cpython-39.pyc ADDED
Binary file (9.47 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-311.pyc ADDED
Binary file (11.3 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-312.pyc ADDED
Binary file (10.5 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (6.43 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-311.pyc ADDED
Binary file (31.5 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-312.pyc ADDED
Binary file (29.4 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/modeling_deepseek_vl_v2.cpython-39.pyc ADDED
Binary file (17.5 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-310.pyc ADDED
Binary file (18.6 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-311.pyc ADDED
Binary file (33.5 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-312.pyc ADDED
Binary file (30.4 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/processing_deepseek_vl_v2.cpython-39.pyc ADDED
Binary file (18.5 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-310.pyc ADDED
Binary file (20.1 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-312.pyc ADDED
Binary file (32.6 kB).
DeepSeek-VL2/deepseek_vl2/models/__pycache__/siglip_vit.cpython-39.pyc ADDED
Binary file (19.8 kB).
DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/gradio_utils.cpython-312.pyc ADDED
Binary file (2.57 kB).
DeepSeek-VL2/deepseek_vl2/serve/app_modules/__pycache__/utils.cpython-312.pyc ADDED
Binary file (15 kB).
DeepSeek-VL2/deepseek_vl2/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes).
DeepSeek-VL2/deepseek_vl2/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (170 Bytes).
DeepSeek-VL2/deepseek_vl2/utils/__pycache__/io.cpython-312.pyc ADDED
Binary file (2.6 kB).
DeepSeek-VL2/deepseek_vl2/utils/__pycache__/io.cpython-39.pyc ADDED
Binary file (1.94 kB).
DeepSeek-VL2/images/logo.png ADDED
DeepSeek-VL2/images/logo.svg ADDED
DeepSeek-VL2/images/monday.jpg ADDED
DeepSeek-VL2/images/visual_grounding_2.jpg ADDED
DeepSeek-VL2/images/vl2_teaser.jpeg ADDED
VLM2Vec/archive/gather_score_byckpt_aws.py ADDED
@@ -0,0 +1,132 @@
1
+ import os
2
+ import json
3
+ import re
4
+
5
+ # Define the datasets
6
+ datasets = [
7
+ "ImageNet-1K", "N24News", "HatefulMemes", "VOC2007", "SUN397", "Place365", "ImageNet-A", "ImageNet-R", "ObjectNet", "Country211",
8
+ "OK-VQA", "A-OKVQA", "DocVQA", "InfographicsVQA", "ChartQA", "Visual7W", "ScienceQA", "VizWiz", "GQA", "TextVQA",
9
+ "VisDial", "CIRR", "VisualNews_t2i", "VisualNews_i2t", "MSCOCO_t2i", "MSCOCO_i2t", "NIGHTS", "WebQA", "FashionIQ", "Wiki-SS-NQ", "OVEN", "EDIS",
10
+ "MSCOCO", "RefCOCO", "RefCOCO-Matching", "Visual7W-Pointing"
11
+ ]
12
+
13
+ # Define the root directory containing the experiment directories
14
+ checkpoint_paths = [
15
+ # llava-next
16
+ "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_mistral-3.lora8.mmeb20_sub100k-1344.bs1024pergpu128.GCq1p1.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-1000/",
17
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_mistral-3.lora8.mmeb20_sub100k-1344.bs1024pergpu128.GCq1p1.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-1400/",
18
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_mistral-1.lora8.mmeb20_sub50k.bs256pergpu32.GCq2p2.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
19
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_mistral-2.lora8.mmeb20_sub50k.bs1024pergpu128.GCq2p2.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
20
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-e5v-1.lora8.mmeb20_sub50k.bs1024pergpu128.GCq2p2.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
21
+
22
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_vicuna-1.lora8.mmeb20_sub50k.bs256pergpu32.GCq2p2.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
23
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-llava16_vicuna-2.lora8.mmeb20_sub50k.bs1024pergpu128.GCq2p2.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
24
+
25
+ # scale-up
26
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale002.lora8.mmeb17_sub100k_NoMSCOCO.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr5e5.step5kwarm200.8H100/checkpoint-1500/",
27
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale001.lora8.mmeb20_sub100k.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step5kwarm200.8H100/checkpoint-1500/",
28
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale001.lora8.mmeb20_sub100k.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step5kwarm200.8H100/checkpoint-2500/",
29
+
30
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale002-1.lora8.mmeb17_sub100k_NoMSCOCO.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step2kwarm100.8H100/checkpoint-1000/",
31
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale002-1.lora8.mmeb17_sub100k_NoMSCOCO.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step2kwarm100.8H100/checkpoint-1500/",
32
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb005-scale002-1.lora8.mmeb17_sub100k_NoMSCOCO.bs1024pergpu128.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
33
+
34
+ # batch size
35
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-bs1024.fullmodel.mmeb20_sub50k.bs1024pergpu128.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
36
+ # # task
37
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-taskVQA.fullmodel.mmeb20_sub50k.bs64pergpu8.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
38
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-taskRET.fullmodel.mmeb20_sub50k.bs64pergpu8.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
39
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-taskCLS.fullmodel.mmeb20_sub50k.bs64pergpu8.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
40
+ # # lora
41
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-lora8.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
42
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-lora32.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
43
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-lora8_bs1k.mmeb20_sub50k.bs1024pergpu128.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
44
+ # # maxlen
45
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-len128.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len128crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
46
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-len512.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len512crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
47
+ # # step
48
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-step1k.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step1kwarm50.8H100/checkpoint-1000/",
49
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-step4k.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step4kwarm200.8H100/checkpoint-4000/",
50
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-step8k.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step8kwarm400.8H100/checkpoint-8000/",
51
+ # # crop
52
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-crop1.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop1.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
53
+ # "/fsx/sfr/data/MMEB_exp/mmeb004-crop2.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop2.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
54
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-crop9.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq2p2.phi35.NormTemp002.len256crop9.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
55
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-crop16.fullmodel.mmeb20_sub50k.bs256pergpu32.GCq1p1.phi35.NormTemp002.len256crop16.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
56
+ # data size
57
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-lora8_bs1k.mmeb20_sub50k.bs1024pergpu128.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-1000/",
58
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-lora4.mmeb20_sub50k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step2kwarm100.8H100/checkpoint-2000/",
59
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-data25k.fullmodel.mmeb20_sub25k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step4kwarm200.8H100/checkpoint-4000/",
60
+ # "/fsx/home/ruimeng/runs/mmeb/mmeb004-data100k.fullmodel.mmeb20_sub100k.bs256pergpu32.GCq4p4.phi35.NormTemp002.len256crop4.lr2e5.step4kwarm200.8H100/checkpoint-4000/",
61
+ ]
62
+
63
+
64
+ # Function to extract step number from checkpoint directory name
65
+ def extract_step(checkpoint_name):
66
+ match = re.search(r'checkpoint-(\d+)', checkpoint_name)
67
+ return int(match.group(1)) if match else float('inf')
68
+
69
+
70
+ # Dictionary to hold all gathered scores, organized by experiment
71
+ gathered_scores_by_exp = {}
72
+
73
+ # Loop through checkpoint directories
74
+ for checkpoint_path in checkpoint_paths:
75
+ step = extract_step(checkpoint_path)
76
+ experiment_dir = checkpoint_path.split("/")[-3]
77
+
78
+ # Check if it is a checkpoint directory, and a valid checkpoint dir
79
+ if str.isdigit(str(step)):
80
+ # Initialize a dictionary to store scores for this checkpoint
81
+ checkpoint_scores = {"experiment": experiment_dir, "checkpoint": str(step)}
82
+
83
+ # Go through each dataset and check if the corresponding score file exists
84
+ for dataset in datasets:
85
+ score_file = os.path.join(checkpoint_path, f"{dataset}_score.json") # Score file named like DatasetName_score.json
86
+
87
+ # Check if the score file exists
88
+ if os.path.isfile(score_file):
89
+ with open(score_file, "r") as f:
90
+ score_data = json.load(f) # Load the score JSON
91
+ checkpoint_scores[dataset] = score_data.get("acc", "N/A") # Assuming 'acc' is the key for accuracy
92
+ else:
93
+ checkpoint_scores[dataset] = "N/A" # If no score file, set to 'N/A'
94
+
95
+ # Append the scores for this checkpoint to the respective experiment group
96
+ gathered_scores_by_exp[experiment_dir] = checkpoint_scores
97
+
98
+
99
+
100
+ print('\n' * 5)
101
+ # Print gathered scores in a comma-separated format
102
+ header = ["experiment", "checkpoint"] + datasets
103
+ print(",".join(header)) # Print header
104
+
105
+ for experiment, scores in gathered_scores_by_exp.items():
106
+ row = [scores["experiment"], scores["checkpoint"]] + [str(scores[dataset]) for dataset in datasets]
107
+ print(",".join(row)) # Print each row of scores
108
+
109
+
110
+
111
+ header = ["dataset"] + list(gathered_scores_by_exp.keys())
112
+ print(",".join(header)) # Print header
113
+ # Additional Block: Print results per experiment, transposed (dataset per row, step per column)
114
+ # Print dataset names in the first column, and the scores for each checkpoint in subsequent columns
115
+ for dataset in datasets:
116
+ row = []
117
+ for experiment, scores in gathered_scores_by_exp.items():
118
+ row.append(str(scores[dataset]))
119
+ print(",".join([dataset] + row)) # Print header
120
+
121
+
122
+
123
+
124
+ # header = ["dataset"] + list(gathered_scores_by_exp.keys())
125
+ # print(",".join(header)) # Print header
126
+ # # Additional Block: Print results per experiment, transposed (dataset per row, step per column)
127
+ # # Print dataset names in the first column, and the scores for each checkpoint in subsequent columns
128
+ # for dataset in datasets:
129
+ # print(",".join([dataset, str(scores[dataset])]))
130
+ # for experiment, scores in gathered_scores_by_exp.items():
131
+ # print(f"\nResults for {experiment}:")
132
+ #
VLM2Vec/archive/merge.py ADDED
1
+ from src.arguments import ModelArguments
2
+ from transformers import HfArgumentParser, AutoProcessor
3
+
4
+ from src.model import MMEBModel
5
+ from evaluation.eval_utils import get_pred
6
+
7
+
8
+ def main():
9
+ parser = HfArgumentParser(ModelArguments)
10
+ model_args, = parser.parse_args_into_dataclasses()
11
+ model_args: ModelArguments
12
+
13
+ processor = AutoProcessor.from_pretrained(
14
+ model_args.model_name,
15
+ trust_remote_code=True,
16
+ num_crops=model_args.num_crops,
17
+ )
18
+
19
+ processor.tokenizer.padding_side = "right"
20
+ model = MMEBModel.load(model_args)
21
+ model.encoder._hf_peft_config_loaded = False
22
+ model.encoder.save_pretrained('full_model/', safe_serialization=False)
23
+
24
+
25
+ if __name__ == "__main__":
26
+ main()
VLM2Vec/archive/testset_stats.py ADDED
1
+ import json
2
+ import sys
3
+
4
+ import numpy as np
5
+
6
+ from src.arguments import ModelArguments, DataArguments, TrainingArguments
7
+ from transformers import HfArgumentParser, AutoProcessor
8
+ from src.dataset import EvalDataset
9
+ import re
10
+
11
+ def main():
12
+ for arg in sys.argv:
13
+ if arg.startswith("--local-rank="):
14
+ rank = arg.split("=")[1]
15
+ sys.argv.remove(arg)
16
+ sys.argv.append('--local_rank')
17
+ sys.argv.append(rank)
18
+ parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
19
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
20
+ model_args: ModelArguments
21
+ data_args: DataArguments
22
+ training_args: TrainingArguments
23
+
24
+ datasets = [
25
+ "GQA",
26
+ # "ImageNet-1K", "N24News", "HatefulMemes", "VOC2007", "SUN397", "Place365", "ImageNet-A", "ImageNet-R",
27
+ # "ObjectNet", "Country211",
28
+ # "OK-VQA", "A-OKVQA", "DocVQA", "InfographicsVQA", "ChartQA", "Visual7W", "ScienceQA", "VizWiz", "GQA",
29
+ # "TextVQA",
30
+ # "VisDial", "CIRR", "VisualNews_t2i", "VisualNews_i2t", "MSCOCO_t2i", "MSCOCO_i2t", "NIGHTS", "WebQA",
31
+ # "FashionIQ", "Wiki-SS-NQ", "OVEN", "EDIS",
32
+ # "MSCOCO", "RefCOCO", "RefCOCO-Matching", "Visual7W-Pointing"
33
+ ]
34
+
35
+ # ToDo: This part of code is a little bit hacky. Need to refactor later.
36
+ for idx, subset in enumerate(datasets):
37
+ eval_qry_dataset = EvalDataset(
38
+ data_args=data_args,
39
+ model_args=model_args,
40
+ subset=subset,
41
+ text_field="qry_text",
42
+ img_path_field="qry_img_path",
43
+ )
44
+ eval_tgt_dataset = EvalDataset(
45
+ data_args=data_args,
46
+ model_args=model_args,
47
+ subset=subset,
48
+ text_field="tgt_text",
49
+ img_path_field="tgt_img_path",
50
+ )
51
+ tgttokens = []
52
+ tgtstr_lens = []
53
+ for tgt in eval_tgt_dataset:
54
+ # print(tgt)
55
+ tokens = re.split('[^a-zA-Z]', tgt[0])
56
+ tgttokens.append(tokens)
57
+ tgtstr_lens.append(len(tokens))
58
+ pass
59
+
60
+ print(f'dataset: {subset}')
61
+ print(f'tgt-avg-len: {np.mean(tgtstr_lens)}')
62
+ pass
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
VLM2Vec/evaluation/eval_flickr.py ADDED
@@ -0,0 +1,124 @@
1
+ from transformers import HfArgumentParser, AutoProcessor
2
+
3
+ from src.arguments import ModelArguments, DataArguments, TrainingArguments
4
+ from src.model import MMEBModel
5
+ from src.dataset import FlickrDataset
6
+ from src.collator import EvalCollator
7
+ from src.utils import load_processor
8
+
9
+ from torch.utils.data import DataLoader
10
+ import torch
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import pickle
14
+ import os
15
+ from datasets import load_dataset
16
+ from eval_utils import get_pred
17
+
18
+
19
+ def main():
20
+ parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
21
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
22
+ model_args: ModelArguments
23
+ data_args: DataArguments
24
+ training_args: TrainingArguments
25
+
26
+ processor = load_processor(model_args)
27
+
28
+ eval_img_dataset = FlickrDataset(
29
+ modality="image", model_backbone=model_args.model_backbone
30
+ )
31
+ eval_txt_dataset = FlickrDataset(
32
+ modality="text", model_backbone=model_args.model_backbone
33
+ )
34
+ eval_collator = EvalCollator(
35
+ data_args=data_args,
36
+ model_args=model_args,
37
+ processor=processor,
38
+ )
39
+
40
+ model = MMEBModel.load(model_args)
41
+ model.eval()
42
+ model = model.to(training_args.device, dtype=torch.bfloat16)
43
+
44
+ eval_img_loader = DataLoader(
45
+ eval_img_dataset,
46
+ batch_size=training_args.per_device_eval_batch_size,
47
+ collate_fn=eval_collator,
48
+ shuffle=False,
49
+ drop_last=False,
50
+ num_workers=training_args.dataloader_num_workers,
51
+ )
52
+ eval_txt_loader = DataLoader(
53
+ eval_txt_dataset,
54
+ batch_size=training_args.per_device_eval_batch_size,
55
+ collate_fn=eval_collator,
56
+ shuffle=False,
57
+ drop_last=False,
58
+ num_workers=training_args.dataloader_num_workers,
59
+ )
60
+
61
+ encode_img_path = os.path.join(data_args.encode_output_path, f"flickr_image_1K-crop{model_args.num_crops}")
62
+ encode_txt_path = os.path.join(data_args.encode_output_path, f"flickr_text_1K-crop{model_args.num_crops}")
63
+
64
+ encoded_tensor = []
65
+ with torch.no_grad():
66
+ for batch in tqdm(eval_img_loader, desc="Encode image"):
67
+ batch = {key: value.to(training_args.device) for key, value in batch.items()}
68
+ output = model(qry=batch)
69
+ encoded_tensor.append(output["qry_reps"].cpu().detach().float().numpy())
70
+ encoded_tensor = np.concatenate(encoded_tensor)
71
+ with open(encode_img_path, 'wb') as f:
72
+ pickle.dump((encoded_tensor, eval_img_dataset.image_names), f)
73
+
74
+ encoded_tensor = []
75
+ with torch.no_grad():
76
+ for batch in tqdm(eval_txt_loader, desc="Encode text"):
77
+ batch = {key: value.to(training_args.device) for key, value in batch.items()}
78
+ output = model(qry=batch)
79
+ encoded_tensor.append(output["qry_reps"].cpu().detach().float().numpy())
80
+ encoded_tensor = np.concatenate(encoded_tensor)
81
+ with open(encode_txt_path, 'wb') as f:
82
+ pickle.dump((encoded_tensor, eval_txt_dataset.image_names), f)
83
+
84
+ with open(encode_img_path, 'rb') as f:
85
+ img_tensor, i2t_name = pickle.load(f)
86
+ img_tensor = torch.from_numpy(img_tensor)
87
+ with open(encode_txt_path, 'rb') as f:
88
+ txt_tensor, t2i_name = pickle.load(f)
89
+ txt_tensor = torch.from_numpy(txt_tensor)
90
+
91
+ # I -> T
92
+ similarity_matrix = torch.matmul(img_tensor, txt_tensor.T)
93
+ recall_at_k = {1: 0, 5: 0, 10: 0}
94
+ sorted_indices = torch.argsort(similarity_matrix, dim=1, descending=True)
95
+ for idx, file_name in enumerate(i2t_name):
96
+ top_k_indices = sorted_indices[idx, :10] # Get top-10 indices
97
+ top_k_file_names = [t2i_name[i.item()] for i in top_k_indices]
98
+ for k in [1, 5, 10]:
99
+ if file_name in top_k_file_names[:k]:
100
+ recall_at_k[k] += 1
101
+
102
+ for k in [1, 5, 10]:
103
+ recall_at_k[k] = recall_at_k[k] / len(i2t_name)
104
+ print(f"\033[91m Recall@{k}: {recall_at_k[k]:.4f}\033[0m")
105
+
106
+
107
+ # T -> I
108
+ similarity_matrix = torch.matmul(txt_tensor, img_tensor.T)
109
+ recall_at_k = {1: 0, 5: 0, 10: 0}
110
+ sorted_indices = torch.argsort(similarity_matrix, dim=1, descending=True)
111
+ for idx, file_name in enumerate(t2i_name):
112
+ top_k_indices = sorted_indices[idx, :10]
113
+ top_k_file_names = [i2t_name[i.item()] for i in top_k_indices]
114
+ for k in [1, 5, 10]:
115
+ if file_name in top_k_file_names[:k]:
116
+ recall_at_k[k] += 1
117
+
118
+ for k in [1, 5, 10]:
119
+ recall_at_k[k] = recall_at_k[k] / len(t2i_name)
120
+ print(f"\033[91m Recall@{k}: {recall_at_k[k]:.4f}\033[0m")
121
+
122
+
123
+ if __name__ == "__main__":
124
+ main()
VLM2Vec/figures/example.jpg ADDED
VLM2Vec/grad_cache/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ try:
2
+ from .grad_cache import GradCache
3
+ except ModuleNotFoundError:
4
+ pass
VLM2Vec/grad_cache/cachex/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .functional import chunk_encode, cache_grad, unchunk_args
2
+ from .tree_utils import tree_chunk, tree_unchunk
3
+
VLM2Vec/grad_cache/context_managers.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from torch.utils.checkpoint import get_device_states, set_device_states
3
+
4
+
5
+ class RandContext:
6
+ def __init__(self, *tensors):
7
+ self.fwd_cpu_state = torch.get_rng_state()
8
+ self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
9
+
10
+ def __enter__(self):
11
+ self._fork = torch.random.fork_rng(
12
+ devices=self.fwd_gpu_devices,
13
+ enabled=True
14
+ )
15
+ self._fork.__enter__()
16
+ torch.set_rng_state(self.fwd_cpu_state)
17
+ set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
18
+
19
+ def __exit__(self, exc_type, exc_val, exc_tb):
20
+ self._fork.__exit__(exc_type, exc_val, exc_tb)
21
+ self._fork = None
VLM2Vec/grad_cache/functional.py ADDED
@@ -0,0 +1,91 @@
1
+ from functools import wraps
2
+ from typing import Callable, Union, Tuple, Any
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch import distributed as dist
7
+
8
+ from .context_managers import RandContext
9
+
10
+
11
+ def cached(func: Callable[..., Tensor]):
12
+ """
13
+ A decorator that turns a pytorch model call function into a cache-compatible version.
14
+ :param func: A function that calls the pytorch and return representation tensor.
15
+ :return: A function that returns 1) representation leaf tensors for cache construction, 2) a closure function for
16
+ the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor.
17
+ """
18
+ @wraps(func)
19
+ def cache_func(*args, **kwargs):
20
+ rnd_state = RandContext()
21
+ with torch.no_grad():
22
+ reps_no_grad = func(*args, **kwargs)
23
+ if isinstance(reps_no_grad, Tensor):
24
+ reps_no_grad = (reps_no_grad, )
25
+ else:
26
+ assert all(isinstance(v, Tensor) for v in reps_no_grad)
27
+ leaf_reps = tuple(t.detach().requires_grad_() for t in reps_no_grad)
28
+
29
+ @wraps(func)
30
+ def forward_backward_func(cache_reps: Union[Tensor, Tuple[Tensor]]):
31
+ with rnd_state:
32
+ reps = func(*args, **kwargs)
33
+ if isinstance(reps, Tensor):
34
+ reps = (reps,)
35
+ if isinstance(cache_reps, Tensor):
36
+ cache_reps = (cache_reps,)
37
+ assert len(reps) == len(cache_reps)
38
+
39
+ surrogate = sum(map(lambda u, v: torch.dot(u.flatten(), v.grad.flatten()), reps, cache_reps), 0)
40
+ surrogate.backward()
41
+
42
+ return leaf_reps + (forward_backward_func,)
43
+ return cache_func
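
As a hedged illustration of the protocol described in the docstring above (the names `encode`, `model`, `q_batch`, `p_batch`, and `loss_fn` are stand-ins, not part of this file), a typical call sequence might look like:

```python
# Illustrative sketch only; `model`, `q_batch`, `p_batch`, and `loss_fn` are assumed.
@cached
def encode(model, batch):
    return model(batch)  # returns a representation Tensor

q_rep, q_closure = encode(model, q_batch)  # graph-less forward, detached leaf rep
p_rep, p_closure = encode(model, p_batch)

loss = loss_fn(q_rep, p_rep)  # build the loss on the leaf representations
loss.backward()               # populates q_rep.grad / p_rep.grad (the gradient cache)
q_closure(q_rep)              # second forward + cached backward through the model
p_closure(p_rep)
```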
44
+
45
+
46
+ def _cat_tensor_list(xx):
47
+ if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx):
48
+ return torch.cat(xx)
49
+ else:
50
+ return xx
51
+
52
+
53
+ def cat_input_tensor(func: Callable[..., Tensor]):
54
+ """
55
+ A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor
56
+ on the 0 dimension. This can come in handy dealing with results of representation tensors from multiple
57
+ cached forward.
58
+ :param func: A loss function
59
+ :return: Decorated loss function for cached results.
60
+ """
61
+ @wraps(func)
62
+ def cat_f(*args, **kwargs):
63
+ args_cat = [_cat_tensor_list(x) for x in args]
64
+ kwargs_cat = dict((k, _cat_tensor_list(v)) for k, v in kwargs.items())
65
+ return func(*args_cat, **kwargs_cat)
66
+ return cat_f
67
+
68
+
69
+ def _maybe_gather_tensor(t: Any, axis: int):
70
+ if not isinstance(t, Tensor):
71
+ return t
72
+ gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
73
+ dist.all_gather(gathered, t)
74
+ gathered[dist.get_rank()] = t
75
+ return torch.cat(gathered, dim=axis)
76
+
77
+
78
+ def gather_input_tensor(func: Callable[..., Tensor], axis=0):
79
+ """
80
+ A decorator that all-gather positional and keyword arguments of type Tensor and concatenate them on axis.
81
+ Intended to be used with distributed contrastive learning loss.
82
+ :param func: A loss function
83
+ :param axis: The axis the gathered tensors are concatenated.
84
+ :return: Decorated loss function for distributed training.
85
+ """
86
+ @wraps(func)
87
+ def f(*args, **kwargs):
88
+ args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args]
89
+ kwargs_gathered = dict((k, _maybe_gather_tensor(v, axis=axis)) for k, v in kwargs.items())
90
+ return func(*args_gathered, **kwargs_gathered)
91
+ return f
VLM2Vec/grad_cache/grad_cache.py ADDED
@@ -0,0 +1,279 @@
1
+ from typing import List, Union, Callable, Any, Tuple
2
+ from contextlib import nullcontext
3
+ from itertools import repeat
4
+ from collections import UserDict
5
+ import logging
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+ from torch.cuda.amp import GradScaler, autocast
10
+ from grad_cache.context_managers import RandContext
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class GradCache:
16
+ """
17
+ Gradient Cache class. Implements input chunking, the first graph-less forward pass, gradient cache creation, and the
18
+ second forward & backward gradient computation. The optimizer step is not included. Native torch automatic mixed
19
+ precision is supported. The user needs to handle gradient unscaling and the scaler update after a gradient cache step.
20
+ """
21
+ def __init__(
22
+ self,
23
+ models: List[nn.Module],
24
+ chunk_sizes: Union[int, List[int]],
25
+ loss_fn: Callable[..., Tensor],
26
+ split_input_fn: Callable[[Any, int], Any] = None,
27
+ get_rep_fn: Callable[..., Tensor] = None,
28
+ fp16: bool = False,
29
+ scaler: GradScaler = None,
30
+ ):
31
+ """
32
+ Initialize the Gradient Cache class instance.
33
+ :param models: A list of all encoder models to be updated by the current cache.
34
+ :param chunk_sizes: An integer chunk size, or a list of integers giving the chunk size for each model.
35
+ :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and
36
+ arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations
37
+ in the autograd graph, which are later relied upon to create the gradient cache.
38
+ :param split_input_fn: An optional function that splits generic model input into chunks. If not provided, this
39
+ class will try its best to split the inputs of supported types. See `split_inputs` function.
40
+ :param get_rep_fn: An optional function that takes generic model output and returns representation tensors. If
41
+ not provided, the generic output is assumed to be the representation tensor.
42
+ :param fp16: If True, run mixed precision training, which requires scaler to also be set.
43
+ :param scaler: A GradScaler object for automatic mixed precision training.
44
+ """
45
+ self.models = models
46
+
47
+ if isinstance(chunk_sizes, int):
48
+ self.chunk_sizes = [chunk_sizes for _ in range(len(models))]
49
+ else:
50
+ self.chunk_sizes = chunk_sizes
51
+
52
+ self.split_input_fn = split_input_fn
53
+ self.get_rep_fn = get_rep_fn
54
+ self.loss_fn = loss_fn
55
+
56
+ if fp16:
57
+ assert scaler is not None, "mixed precision training requires a gradient scaler passed in"
58
+
59
+ self.fp16 = fp16
60
+ self.scaler = scaler
61
+
62
+ self._get_input_tensors_strict = False
63
+
64
+ def __call__(self, *args, **kwargs):
65
+ """
66
+ Call the cache_step function.
67
+ :return: Current step loss.
68
+ """
69
+ return self.cache_step(*args, **kwargs)
70
+
71
+ def split_inputs(self, model_input, chunk_size: int) -> List:
72
+ """
73
+ Split input into chunks. Will call user provided `split_input_fn` if specified. Otherwise,
74
+ it can handle input types of tensor, list of tensors and dictionary of tensors.
75
+ :param model_input: Generic model input.
76
+ :param chunk_size: Size of each chunk.
77
+ :return: A list of chunked model input.
78
+ """
79
+ # delegate splitting to user provided function
80
+ if self.split_input_fn is not None:
81
+ return self.split_input_fn(model_input, chunk_size)
82
+
83
+ if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()):
84
+ keys = list(model_input.keys())
85
+ chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
86
+ return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
87
+
88
+ elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input):
89
+ chunked_x = [t.split(chunk_size, dim=0) for t in model_input]
90
+ return [list(s) for s in zip(*chunked_x)]
91
+
92
+ elif isinstance(model_input, Tensor):
93
+ return list(model_input.split(chunk_size, dim=0))
94
+
95
+ elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
96
+ args_chunks = self.split_inputs(model_input[0], chunk_size)
97
+ kwargs_chunks = self.split_inputs(model_input[1], chunk_size)
98
+ return list(zip(args_chunks, kwargs_chunks))
99
+
100
+ else:
101
+ raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')
102
+
103
+ def get_input_tensors(self, model_input) -> List[Tensor]:
104
+ """
105
+ Recursively go through model input and grab all tensors, which are then used to record current device random
106
+ states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other types will
107
+ be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will be raised.
108
+ :param model_input: input to model
109
+ :return: all torch tensors in model_input
110
+ """
111
+ if isinstance(model_input, Tensor):
112
+ return [model_input]
113
+
114
+ elif isinstance(model_input, (list, tuple)):
115
+ return sum((self.get_input_tensors(x) for x in model_input), [])
116
+
117
+ elif isinstance(model_input, (dict, UserDict)):
118
+ return sum((self.get_input_tensors(x) for x in model_input.values()), [])
119
+
120
+ elif self._get_input_tensors_strict:
121
+ raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}')
122
+
123
+ else:
124
+ return []
125
+
126
+ def model_call(self, model: nn.Module, model_input):
127
+ """
128
+ Literally call the model's __call__ method.
129
+ :param model: model to be called
130
+ :param model_input: input to the model call
131
+ :return: model output
132
+ """
133
+ with autocast() if self.fp16 else nullcontext():
134
+ if isinstance(model_input, Tensor):
135
+ return model(model_input)
136
+ elif isinstance(model_input, list):
137
+ return model(*model_input)
138
+ elif isinstance(model_input, (dict, UserDict)):
139
+ return model(**model_input)
140
+ elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
141
+ model_args, model_kwargs = model_input
142
+ return model(*model_args, **model_kwargs)
143
+ else:
144
+ raise NotImplementedError
145
+
146
+ def get_reps(self, model_out) -> Tensor:
147
+ """
148
+ Return representation tensor from generic model output
149
+ :param model_out: generic model output
150
+ :return: a single tensor corresponding to the model representation output
151
+ """
152
+ if self.get_rep_fn is not None:
153
+ return self.get_rep_fn(model_out)
154
+ else:
155
+ return model_out
156
+
157
+ def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor:
158
+ """
159
+ Compute the loss based on the representation tensors. The tensors should be ordered the same as the list of models
160
+ registered in this GradCache class instance.
161
+ :param reps: Representations for computing the loss.
162
+ :param loss_kwargs: Keyword arguments input to the loss function.
163
+ :return: the loss tensor.
164
+ """
165
+ loss = self.loss_fn(*reps, **loss_kwargs)
166
+ return loss
167
+
168
+ def forward_no_grad(
169
+ self,
170
+ model: nn.Module,
171
+ model_inputs,
172
+ ) -> Tuple[Tensor, List[RandContext]]:
173
+ """
174
+ The first forward pass without gradient computation.
175
+ :param model: Encoder model.
176
+ :param model_inputs: Model input already broken into chunks.
177
+ :return: A tuple of a) representations and b) recorded random states.
178
+ """
179
+ rnd_states = []
180
+ model_reps = []
181
+
182
+ with torch.no_grad():
183
+ for x in model_inputs:
184
+ rnd_states.append(RandContext(*self.get_input_tensors(x)))
185
+ y = self.model_call(model, x)
186
+ model_reps.append(self.get_reps(y))
187
+
188
+ # concatenate all sub-batch representations
189
+ model_reps = torch.cat(model_reps, dim=0)
190
+ return model_reps, rnd_states
191
+
192
+ def build_cache(self, *reps: Tensor, **loss_kwargs) -> Tuple[List[Tensor], Tensor]:
193
+ """
194
+ Compute the gradient cache
195
+ :param reps: Computed representations from all encoder models
196
+ :param loss_kwargs: Extra keyword arguments to the loss function
197
+ :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
198
+ """
199
+ reps = [r.detach().requires_grad_() for r in reps]
200
+ with autocast() if self.fp16 else nullcontext():
201
+ loss = self.compute_loss(*reps, **loss_kwargs)
202
+
203
+ if self.fp16:
204
+ self.scaler.scale(loss).backward()
205
+ else:
206
+ loss.backward()
207
+
208
+ cache = [r.grad for r in reps]
209
+
210
+ return cache, loss.detach()
211
+
212
+ def forward_backward(
213
+ self,
214
+ model: nn.Module,
215
+ model_inputs,
216
+ cached_gradients: List[Tensor],
217
+ random_states: List[RandContext],
218
+ no_sync_except_last: bool = False
219
+ ):
220
+ """
221
+ Run the second forward and the backward pass to compute gradient for a model.
222
+ :param model: Encoder model.
223
+ :param model_inputs: Chunked input to the encoder model.
224
+ :param cached_gradients: Chunked gradient cache tensor for each input.
225
+ :param random_states: Each input's device random state during the first forward.
226
+ :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
227
+ for the last sub-batch's forward-backward pass.
228
+ """
229
+ if no_sync_except_last:
230
+ sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
231
+ else:
232
+ sync_contexts = [nullcontext for _ in range(len(model_inputs))]
233
+
234
+ for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
235
+ with sync_context():
236
+ with state:
237
+ y = self.model_call(model, x)
238
+ reps = self.get_reps(y)
239
+
240
+ surrogate = torch.dot(reps.flatten(), gradient.flatten())
241
+ surrogate.backward()
242
+
243
+ def cache_step(
244
+ self,
245
+ *model_inputs,
246
+ no_sync_except_last: bool = False,
247
+ **loss_kwargs
248
+ ) -> Tensor:
249
+ """
250
+ Run a cached step to compute gradient over the inputs.
251
+ :param model_inputs: Input to each encoder model. Should be in the same order as the class's models.
252
+ :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
253
+ across processes for the last sub-batch's forward-backward pass.
254
+ :param loss_kwargs: Additional keyword arguments to the loss function.
255
+ :return: The current step's loss.
256
+ """
257
+ all_reps = []
258
+ all_rnd_states = []
259
+
260
+ if no_sync_except_last:
261
+ assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
262
+ 'Some of the models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
263
+ 'proper initializations.'
264
+
265
+ model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]
266
+
267
+ for model, x in zip(self.models, model_inputs):
268
+ model_reps, rnd_states = self.forward_no_grad(model, x)
269
+ all_reps.append(model_reps)
270
+ all_rnd_states.append(rnd_states)
271
+
272
+ cache, loss = self.build_cache(*all_reps, **loss_kwargs)
273
+ cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
274
+
275
+ for model, x, model_cache, rnd_states in zip(
276
+ self.models, model_inputs, cache, all_rnd_states):
277
+ self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last)
278
+
279
+ return loss
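For reference, here is a minimal sketch of driving the `GradCache` class above: register the encoders and a loss function, choose a chunk size, and call the instance once per optimization step. The toy encoders, in-batch loss, and batch sizes below are illustrative assumptions, not the repository's actual training setup.

```python
# Illustrative sketch (toy encoders/loss, assumptions): one GradCache step.
import torch
from torch import nn
from torch.nn import functional as F

from grad_cache.grad_cache import GradCache

def contrastive_loss(q_reps, p_reps):
    logits = q_reps @ p_reps.t()
    labels = torch.arange(q_reps.size(0), device=q_reps.device)
    return F.cross_entropy(logits, labels)

q_encoder = nn.Linear(32, 32)   # toy query encoder, assumption
p_encoder = nn.Linear(32, 32)   # toy passage encoder, assumption
optimizer = torch.optim.AdamW(
    list(q_encoder.parameters()) + list(p_encoder.parameters()), lr=1e-3)

gc = GradCache(
    models=[q_encoder, p_encoder],
    chunk_sizes=8,              # sub-batch size used by both forward passes
    loss_fn=contrastive_loss,
)

queries = torch.randn(64, 32)   # large effective batch, processed in chunks of 8
passages = torch.randn(64, 32)

optimizer.zero_grad()
loss = gc(queries, passages)    # chunked no-grad forward, cache build, second forward/backward
optimizer.step()
print(loss.item())
```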
VLM2Vec/grad_cache/loss.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.nn import functional as F
4
+ from torch import distributed as dist
5
+
6
+ from src import dist_utils
7
+
8
+
9
+ class InExampleContrastiveLoss:
10
+ """
11
+ Categorization loss: cross_entropy of 1 out of K classes (target labels)
12
+ x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]
13
+ """
14
+ def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, ndim: int = None, *args, **kwargs):
15
+ self.target_per_qry = n_hard_negatives + 1
16
+ self.temperature = temperature
17
+ self.ndim = ndim
18
+
19
+ def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
20
+ # print("gather InExampleContrastiveLoss")
21
+ if torch.distributed.is_initialized():
22
+ x = dist_utils.dist_gather(x)
23
+ y = dist_utils.dist_gather(y)
24
+ bsz, ndim = x.size(0), x.size(1)
25
+ target = torch.zeros(bsz, dtype=torch.long, device=x.device)
26
+ if self.ndim:
27
+ ndim = self.ndim
28
+ x = x[:, :ndim]
29
+ y = y[:, :ndim]
30
+ logits = torch.einsum('bod,bsd->bs', x.view(bsz, 1, ndim), y.view(bsz, -1, ndim)) * self.temperature
31
+ preds = torch.argmax(logits, dim=-1)
32
+ loss = F.cross_entropy(logits, target, reduction=reduction)
33
+ loss_detail = {"logits": logits, "labels": target, "preds": preds}
34
+ return loss, loss_detail
35
+
36
+
37
+ class SimpleContrastiveLoss:
38
+ def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
39
+ self.target_per_qry = n_hard_negatives + 1
40
+ self.temperature = temperature
41
+
42
+ def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
43
+ # print("gather SimpleContrastiveLoss")
44
+ if target is None:
45
+ assert x.size(0) * self.target_per_qry == y.size(0)
46
+ target = torch.arange(0, y.size(0), step=self.target_per_qry, dtype=torch.long, device=x.device)
47
+ logits = torch.matmul(x, y.transpose(0, 1)) * self.temperature
48
+ preds = torch.argmax(logits, dim=-1)
49
+ loss = F.cross_entropy(logits, target, reduction=reduction)
50
+ loss_detail = {"logits": logits, "labels": target, "preds": preds}
51
+ return loss, loss_detail
52
+
53
+
54
+ class DistributedContrastiveLoss(SimpleContrastiveLoss):
55
+ def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0, *args, **kwargs):
56
+ assert dist.is_initialized(), "Distributed training has not been properly initialized."
57
+
58
+ super().__init__(n_hard_negatives=n_hard_negatives, temperature=temperature)
59
+ self.world_size = dist.get_world_size()
60
+ self.rank = dist.get_rank()
61
+
62
+ def __call__(self, x: Tensor, y: Tensor, **kwargs):
63
+ # print("gather DistributedContrastiveLoss")
64
+ dist_x = self.gather_tensor(x)
65
+ dist_y = self.gather_tensor(y)
66
+
67
+ return super().__call__(dist_x, dist_y, **kwargs)
68
+
69
+ def gather_tensor(self, t):
70
+ gathered = [torch.empty_like(t) for _ in range(self.world_size)]
71
+ dist.all_gather(gathered, t)
72
+ gathered[self.rank] = t
73
+ return torch.cat(gathered, dim=0)
74
+
75
+
76
+ LossName2LossCls = {
77
+ "inexample_contrastive": InExampleContrastiveLoss,
78
+ "inbatch_contrastive": SimpleContrastiveLoss,
79
+ "distributed_inbatch_contrastive": DistributedContrastiveLoss,
80
+ }
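A small sketch of calling `SimpleContrastiveLoss` directly is shown below. Note that importing `grad_cache.loss` assumes VLM2Vec's `src` package is importable, since the module pulls in `src.dist_utils`; the random tensors are purely illustrative. With `n_hard_negatives=1`, each query is scored against all targets and its positive is expected at index `2 * i`.

```python
# Illustrative sketch (random tensors, assumptions): SimpleContrastiveLoss by hand.
# Requires VLM2Vec's `src` package on PYTHONPATH because grad_cache.loss imports it.
import torch

from grad_cache.loss import SimpleContrastiveLoss

loss_fn = SimpleContrastiveLoss(n_hard_negatives=1, temperature=1.0)

x = torch.randn(4, 16)        # 4 query embeddings
y = torch.randn(8, 16)        # 2 targets per query: positive followed by a hard negative
loss, detail = loss_fn(x, y)  # default target indices are [0, 2, 4, 6]
print(loss.item(), detail["preds"])
```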
VLM2Vec/grad_cache/minigc_cmd.md ADDED
@@ -0,0 +1,90 @@
1
+
2
+
3
+ #### Prerequisite
4
+ ```bash
5
+ ENV_PATH=/export/share/ruimeng/env/anaconda/envs/llm/bin/ninja
6
+ export PATH="${ENV_PATH}/:$PATH"
7
+
8
+ export NCCL_DEBUG=WARN
9
+ export HF_DATASETS_CACHE=/export/xgen-embedding/data/.hfdata_cache
10
+ export TRANSFORMERS_CACHE=/export/xgen-embedding/data/.hfmodel_cache/
11
+ export TOKENIZERS_PARALLELISM=true
12
+ export WANDB_DISABLED=false
13
+ export WANDB_PROJECT=mini-gradcache
14
+ export WANDB_API_KEY=local-d64a4127e8d4a1782aedbb72e76080b3dfbf89dd
15
+ export WANDB_BASE_URL=https://salesforceairesearch.wandb.io
16
+ ```
17
+
18
+ ```bash
19
+ # gpu0-3, DDP4-bs4096-accum4, 29922MB, hang at epoch34
20
+ export EXP_NAME=GC-4gpu-bs4096-accum16-step10k
21
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
22
+ export WANDB_DIR=$EXP_DIR/wandb
23
+ export WANDB_NAME=$EXP_NAME
24
+ export WORLD_SIZE=4
25
+ mkdir -p $EXP_DIR/wandb
26
+ rm -rf $EXP_DIR/*
27
+ cd /export/home/project/search/xgen-embedding/
28
+ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=4403 --max_restarts=0 mini_gc.py --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 4096 --chunk_sizes 256 2>&1 | tee $EXP_DIR/train.log
29
+
30
+
31
+ # gpu0-3, DDP4-bs256-accum4, 11818MB
32
+ export EXP_NAME=GC-4gpu-bs256-accum4-step10k
33
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
34
+ export WANDB_DIR=$EXP_DIR/wandb
35
+ export WANDB_NAME=$EXP_NAME
36
+ export WORLD_SIZE=4
37
+ mkdir -p $EXP_DIR/wandb
38
+ rm -rf $EXP_DIR/*
39
+ cd /export/home/project/search/xgen-embedding/
40
+ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=4403 --max_restarts=0 mini_gc.py --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 64 --chunk_sizes 16 2>&1 | tee $EXP_DIR/train.log
41
+
42
+
43
+
44
+ # gpu45, DDP2-bs256-accum2, 15742MB
45
+ export EXP_NAME=GC-2gpu-bs256-accum2-step10k
46
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
47
+ export WANDB_DIR=$EXP_DIR/wandb
48
+ export WANDB_NAME=$EXP_NAME
49
+ export WORLD_SIZE=1
50
+ mkdir -p $EXP_DIR/wandb
51
+ rm -rf $EXP_DIR/*
52
+ cd /export/home/project/search/xgen-embedding/
53
+ CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=2245 --max_restarts=0 mini_gc.py --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 128 --chunk_sizes 64 2>&1 | tee $EXP_DIR/train.log
54
+
55
+
56
+ # gpu6, bs256-accum4, 9GB
57
+ export EXP_NAME=GC-1gpu-bs256-accum4-step10k
58
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
59
+ export WANDB_DIR=$EXP_DIR/wandb
60
+ export WANDB_NAME=$EXP_NAME
61
+ export WORLD_SIZE=1
62
+ mkdir -p $EXP_DIR/wandb
63
+ rm -rf $EXP_DIR/*
64
+ cd /export/home/project/search/xgen-embedding/
65
+ CUDA_VISIBLE_DEVICES=6 python -m mini_gc --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 256 --chunk_sizes 64 2>&1 | tee $EXP_DIR/train.log
66
+
67
+
68
+ # gpu6, bs256-accum2, 18GB
69
+ export EXP_NAME=GC-1gpu-bs256-accum2-step10k
70
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
71
+ export WANDB_DIR=$EXP_DIR/wandb
72
+ export WANDB_NAME=$EXP_NAME
73
+ export WORLD_SIZE=1
74
+ mkdir -p $EXP_DIR/wandb
75
+ rm -rf $EXP_DIR/*
76
+ cd /export/home/project/search/xgen-embedding/
77
+ CUDA_VISIBLE_DEVICES=6 python -m mini_gc --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 256 --chunk_sizes 128 2>&1 | tee $EXP_DIR/train.log
78
+
79
+
80
+ # gpu7, bs256-accum1, 38012MB
81
+ export EXP_NAME=GC-1gpu-bs256-accum1-step10k-baseline
82
+ export EXP_DIR=/export/xgen-embedding/runs/ruimeng/minimal_gc/$EXP_NAME
83
+ export WANDB_DIR=$EXP_DIR/wandb
84
+ export WANDB_NAME=$EXP_NAME
85
+ export WORLD_SIZE=1
86
+ mkdir -p $EXP_DIR/wandb
87
+ rm -rf $EXP_DIR/*
88
+ cd /export/home/project/search/xgen-embedding/
89
+ CUDA_VISIBLE_DEVICES=7 python -m mini_gc --model_name_or_path bert-base-uncased --output_dir $EXP_DIR --q_len 128 --d_len 256 --batch_size 256 --chunk_sizes -1 2>&1 | tee $EXP_DIR/train.log
90
+ ```
VLM2Vec/scripts/llava_next/demo.py ADDED
@@ -0,0 +1,46 @@
1
+ from src.model import MMEBModel
2
+ from src.arguments import ModelArguments
3
+ from src.utils import load_processor
4
+
5
+ import torch
6
+ from transformers import HfArgumentParser, AutoProcessor
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+
11
+ model_args = ModelArguments(
12
+ model_name='TIGER-Lab/VLM2Vec-LLaVa-Next',
13
+ pooling='last',
14
+ normalize=True,
15
+ model_backbone='llava_next')
16
+
17
+ processor = load_processor(model_args)
18
+
19
+ model = MMEBModel.load(model_args)
20
+ model.eval()
21
+ model = model.to('cuda', dtype=torch.bfloat16)
22
+
23
+ # Image + Text -> Text
24
+ inputs = processor(text='<image> Represent the given image with the following question: What is in the image',
25
+ images=Image.open('figures/example.jpg'),
26
+ return_tensors="pt")
27
+ inputs = {key: value.to('cuda') for key, value in inputs.items()}
28
+ qry_output = model(qry=inputs)["qry_reps"]
29
+
30
+ string = 'A cat and a dog'
31
+ inputs = processor(text=string,
32
+ images=None,
33
+ return_tensors="pt")
34
+ inputs = {key: value.to('cuda') for key, value in inputs.items()}
35
+ tgt_output = model(tgt=inputs)["tgt_reps"]
36
+ print(string, '=', model.compute_similarity(qry_output, tgt_output))
37
+ ## A cat and a dog = tensor([[0.4414]], device='cuda:0', dtype=torch.bfloat16)
38
+
39
+ string = 'A cat and a tiger'
40
+ inputs = processor(text=string,
41
+ images=None,
42
+ return_tensors="pt")
43
+ inputs = {key: value.to('cuda') for key, value in inputs.items()}
44
+ tgt_output = model(tgt=inputs)["tgt_reps"]
45
+ print(string, '=', model.compute_similarity(qry_output, tgt_output))
46
+ ## A cat and a tiger = tensor([[0.3555]], device='cuda:0', dtype=torch.bfloat16)
VLM2Vec/scripts/llava_next/run_eval_flickr_llava_next.sh ADDED
@@ -0,0 +1,21 @@
1
+ export PYTHONPATH=../VLM2Vec/:$PYTHONPATH
2
+
3
+
4
+ CUDA_VISIBLE_DEVICES=0 python evaluation/eval_flickr.py \
5
+ --model_name TIGER-Lab/VLM2Vec-LLaVa-Next \
6
+ --model_backbone llava_next \
7
+ --max_len 256 \
8
+ --pooling last --normalize True \
9
+ --per_device_eval_batch_size 16 \
10
+ --encode_output_path /home/ziyan/MMEB_eval/flickr_new/
11
+
12
+
13
+ ## I -> T:
14
+ #Recall@1: 0.9400
15
+ #Recall@5: 0.9930
16
+ #Recall@10: 0.9960
17
+ #
18
+ ## T -> I
19
+ #Recall@1: 0.8024
20
+ #Recall@5: 0.9494
21
+ #Recall@10: 0.9736
VLM2Vec/src/arguments.py ADDED
@@ -0,0 +1,121 @@
1
+ from dataclasses import dataclass, field
2
+ from transformers import TrainingArguments
3
+ from typing import List
4
+
5
+
6
+ @dataclass
7
+ class ModelArguments:
8
+ model_name: str = field(
9
+ metadata={"help": "huggingface model name or path"}
10
+ )
11
+ model_backbone: str = field(
12
+ metadata={"help": "vlm backbone"}
13
+ )
14
+ processor_name: str = field(
15
+ default=None, metadata={"help": "processor_name, huggingface model name or path"}
16
+ )
17
+ model_type: str = field(
18
+ default=None, metadata={"help": "lavis model type"}
19
+ )
20
+ checkpoint_path: str = field(
21
+ default=None, metadata={"help": "a local model path"}
22
+ )
23
+ pooling: str = field(
24
+ default='last',
25
+ metadata={"help": "pooling method for encoder"}
26
+ )
27
+ normalize: bool = field(
28
+ default=False,
29
+ metadata={"help": "normalize query and passage representations"}
30
+ )
31
+ temperature: float = field(
32
+ default=0.02,
33
+ metadata={"help": "temperature for softmax"}
34
+ )
35
+ lora: bool = field(
36
+ default=False, metadata={"help": "do parameter-efficient fine-tuning with lora"}
37
+ )
38
+ lora_r: int = field(
39
+ default=16,
40
+ metadata={"help": "lora r"}
41
+ )
42
+ lora_alpha: int = field(
43
+ default=64,
44
+ metadata={"help": "lora alpha"}
45
+ )
46
+ lora_dropout: float = field(
47
+ default=0.1,
48
+ metadata={"help": "lora dropout"}
49
+ )
50
+ lora_target_modules: str = field(
51
+ default="qkv_proj,o_proj,gate_up_proj,down_proj,k_proj,q_proj,out_proj,v_proj",
52
+ metadata={"help": "lora target modules"}
53
+ )
54
+ num_crops: int = field(
55
+ default=16,
56
+ metadata={"help": "number of crops used in image encoder"}
57
+ )
58
+
59
+
60
+ @dataclass
61
+ class DataArguments:
62
+ dataset_name: str = field(
63
+ default=None, metadata={"help": "huggingface dataset name"}
64
+ )
65
+ subset_name: List[str] = field(
66
+ default=None, metadata={"help": "Useful for datasets with subsets"}
67
+ )
68
+ dataset_split: str = field(
69
+ default='train', metadata={"help": "dataset split"}
70
+ )
71
+ num_sample_per_subset: int = field(
72
+ default=100, metadata={"help": "number of training samples per subset"}
73
+ )
74
+ image_dir: str = field(
75
+ default=None, metadata={"help": "Image directory path"}
76
+ )
77
+ encode_output_path: str = field(
78
+ default=None, metadata={"help": "encode output path"}
79
+ )
80
+ max_len: int = field(
81
+ default=128, metadata={"help": "The maximum total input sequence length after tokenization."},
82
+ )
83
+ embedding_type: str = field(
84
+ default="", metadata={"help": "embedding type"}
85
+ )
86
+
87
+
88
+ @dataclass
89
+ class TrainingArguments(TrainingArguments):
90
+ image_encoder_freeze: bool = field(
91
+ default=False, metadata={"help": "huggingface model name"}
92
+ )
93
+ output_dir: str = field(
94
+ default=None, metadata={"help": "directory for saving trained models"}
95
+ )
96
+ project_name: str = field(
97
+ default=None, metadata={"help": "project name"}
98
+ )
99
+
100
+ logging_steps: int = field(
101
+ default=1, metadata={"help": "logging steps"}
102
+ )
103
+ num_train_epochs: int = field(
104
+ default=1, metadata={"help": "number of training epochs"}
105
+ )
106
+ grad_cache: bool = field(
107
+ default=False, metadata={"help": "Use gradient cache update"})
108
+ gc_q_chunk_size: int = field(
109
+ default=2, metadata={"help": "query side subset size"})
110
+ gc_p_chunk_size: int = field(
111
+ default=2, metadata={"help": "target side subset size"})
112
+
113
+
114
+ @dataclass
115
+ class MTEBArguments:
116
+ task_types: List[str] = field(
117
+ default=None, metadata={"help": ""}
118
+ )
119
+ tasks: List[str] = field(
120
+ default=None, metadata={"help": ""}
121
+ )