machuofan committed on
Commit 7385f22 · 1 Parent(s): bf8190d
Files changed (50)
  1. .gitattributes +0 -35
  2. constants.py +27 -0
  3. conversation.py +460 -0
  4. i2t.py +192 -0
  5. mm_utils.py +105 -0
  6. model/__init__.py +7 -0
  7. model/__pycache__/__init__.cpython-311.pyc +0 -0
  8. model/__pycache__/__init__.cpython-39.pyc +0 -0
  9. model/__pycache__/arhead.cpython-311.pyc +0 -0
  10. model/__pycache__/arhead.cpython-39.pyc +0 -0
  11. model/__pycache__/builder.cpython-311.pyc +0 -0
  12. model/__pycache__/builder.cpython-39.pyc +0 -0
  13. model/__pycache__/liquid.cpython-311.pyc +0 -0
  14. model/__pycache__/mini_gemini_arch.cpython-311.pyc +0 -0
  15. model/__pycache__/mini_gemini_arch.cpython-39.pyc +0 -0
  16. model/__pycache__/quant.cpython-311.pyc +0 -0
  17. model/__pycache__/quant.cpython-39.pyc +0 -0
  18. model/arhead.py +241 -0
  19. model/builder.py +138 -0
  20. model/language_model/__pycache__/mini_gemini_llama.cpython-311.pyc +0 -0
  21. model/language_model/__pycache__/mini_gemini_llama.cpython-39.pyc +0 -0
  22. model/language_model/mini_gemini_llama.py +488 -0
  23. model/liquid.py +669 -0
  24. model/multimodal_encoder/__pycache__/builder.cpython-311.pyc +0 -0
  25. model/multimodal_encoder/__pycache__/builder.cpython-39.pyc +0 -0
  26. model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc +0 -0
  27. model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc +0 -0
  28. model/multimodal_encoder/__pycache__/eva_encoder.cpython-311.pyc +0 -0
  29. model/multimodal_encoder/__pycache__/eva_encoder.cpython-39.pyc +0 -0
  30. model/multimodal_encoder/__pycache__/openclip_encoder.cpython-311.pyc +0 -0
  31. model/multimodal_encoder/__pycache__/openclip_encoder.cpython-39.pyc +0 -0
  32. model/multimodal_encoder/builder.py +33 -0
  33. model/multimodal_encoder/clip_encoder.py +89 -0
  34. model/multimodal_encoder/eva_encoder.py +551 -0
  35. model/multimodal_encoder/openclip_encoder.py +188 -0
  36. model/multimodal_projector/__pycache__/builder.cpython-311.pyc +0 -0
  37. model/multimodal_projector/__pycache__/builder.cpython-39.pyc +0 -0
  38. model/multimodal_projector/builder.py +50 -0
  39. model/processor/__pycache__/video_processor.cpython-311.pyc +0 -0
  40. model/processor/__pycache__/video_processor.cpython-39.pyc +0 -0
  41. model/processor/video_processor.py +74 -0
  42. model/quant.py +519 -0
  43. t2i.py +224 -0
  44. tools.py +126 -0
  45. unitok/config.py +243 -0
  46. unitok/dist.py +302 -0
  47. unitok/model.py +184 -0
  48. unitok/quant.py +185 -0
  49. unitok/vitamin.py +792 -0
  50. unitok/vqvae.py +175 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
constants.py ADDED
@@ -0,0 +1,27 @@
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ PREDICT_TOKEN_INDEX = -300
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
+ DEFAULT_PREDICT_TOKEN = "<predict>"
+
+ DESCRIPT_PROMPT = [
+     "Describe this image thoroughly.",
+     "Provide a detailed description in this picture.",
+     "Detail every aspect of what's in this picture.",
+     "Explain this image with precision and detail.",
+     "Give a comprehensive description of this visual.",
+     "Elaborate on the specifics within this image.",
+     "Offer a detailed account of this picture's contents.",
+     "Describe in detail what this image portrays.",
+     "Break down this image into detailed descriptions.",
+     "Provide a thorough description of the elements in this image."]
conversation.py ADDED
@@ -0,0 +1,460 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+ GEMMA = auto()
17
+
18
+
19
+ @dataclasses.dataclass
20
+ class Conversation:
21
+ """A class that keeps all conversation history."""
22
+ system: str
23
+ roles: List[str]
24
+ messages: List[List[str]]
25
+ offset: int
26
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
27
+ sep: str = "###"
28
+ sep2: str = None
29
+ version: str = "Unknown"
30
+
31
+ skip_next: bool = False
32
+
33
+ def get_prompt(self):
34
+ messages = self.messages
35
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
36
+ messages = self.messages.copy()
37
+ init_role, init_msg = messages[0].copy()
38
+ init_msg = init_msg[0].replace("<image>", "").strip()
39
+ if 'mmtag' in self.version:
40
+ messages[0] = (init_role, init_msg)
41
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
42
+ messages.insert(1, (self.roles[1], "Received."))
43
+ else:
44
+ messages[0] = (init_role, "<image>\n" + init_msg)
45
+
46
+ if self.sep_style == SeparatorStyle.SINGLE:
47
+ ret = self.system + self.sep
48
+ for role, message in messages:
49
+ if message:
50
+ if type(message) is tuple:
51
+ message = message[0]
52
+ ret += role + ": " + message + self.sep
53
+ else:
54
+ ret += role + ":"
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(messages):
59
+ if message:
60
+ if type(message) is tuple:
61
+ message = message[0]
62
+ ret += role + ": " + message + seps[i % 2]
63
+ else:
64
+ ret += role + ":"
65
+ elif self.sep_style == SeparatorStyle.MPT:
66
+ ret = self.system + self.sep
67
+ for role, message in messages:
68
+ if message:
69
+ if type(message) is tuple:
70
+ message = message[0]
71
+ ret += role + message + self.sep
72
+ else:
73
+ ret += role
74
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
75
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
76
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
77
+ ret = ""
78
+
79
+ for i, (role, message) in enumerate(messages):
80
+ if i == 0:
81
+ assert message, "first message should not be none"
82
+ assert role == self.roles[0], "first message should come from user"
83
+ if message:
84
+ if type(message) is tuple:
85
+ message, _, _ = message
86
+ if i == 0: message = wrap_sys(self.system) + message
87
+ if i % 2 == 0:
88
+ message = wrap_inst(message)
89
+ ret += self.sep + message
90
+ else:
91
+ ret += " " + message + " " + self.sep2
92
+ else:
93
+ ret += ""
94
+ ret = ret.lstrip(self.sep)
95
+ elif self.sep_style == SeparatorStyle.GEMMA:
96
+ seps = [self.sep, self.sep2]
97
+ ret = self.system + seps[0]
98
+ for i, (role, message) in enumerate(messages):
99
+ if message:
100
+ if type(message) is tuple:
101
+ message, _, _ = message
102
+ ret += "<start_of_turn>" + role + "\n" + message + "<end_of_turn>\n" + seps[i % 2]
103
+ else:
104
+ ret += "<start_of_turn>" + role + "\n"
105
+ elif self.sep_style == SeparatorStyle.PLAIN:
106
+ seps = [self.sep, self.sep2]
107
+ ret = self.system
108
+ for i, (role, message) in enumerate(messages):
109
+ if message:
110
+ if type(message) is tuple:
111
+ message, _, _ = message
112
+ ret += message + seps[i % 2]
113
+ else:
114
+ ret += ""
115
+ else:
116
+ raise ValueError(f"Invalid style: {self.sep_style}")
117
+
118
+ return ret
119
+
120
+ def append_message(self, role, message):
121
+ self.messages.append([role, message])
122
+
123
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
124
+ if image_process_mode == "Pad":
125
+ def expand2square(pil_img, background_color=(122, 116, 104)):
126
+ width, height = pil_img.size
127
+ if width == height:
128
+ return pil_img
129
+ elif width > height:
130
+ result = Image.new(pil_img.mode, (width, width), background_color)
131
+ result.paste(pil_img, (0, (width - height) // 2))
132
+ return result
133
+ else:
134
+ result = Image.new(pil_img.mode, (height, height), background_color)
135
+ result.paste(pil_img, ((height - width) // 2, 0))
136
+ return result
137
+ image = expand2square(image)
138
+ elif image_process_mode in ["Default", "Crop"]:
139
+ pass
140
+ elif image_process_mode == "Resize":
141
+ image = image.resize((336, 336))
142
+ else:
143
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
144
+ if max(image.size) > max_len:
145
+ max_hw, min_hw = max(image.size), min(image.size)
146
+ aspect_ratio = max_hw / min_hw
147
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
148
+ longest_edge = int(shortest_edge * aspect_ratio)
149
+ W, H = image.size
150
+ if H > W:
151
+ H, W = longest_edge, shortest_edge
152
+ else:
153
+ H, W = shortest_edge, longest_edge
154
+ image = image.resize((W, H))
155
+ if return_pil:
156
+ return image
157
+ else:
158
+ buffered = BytesIO()
159
+ image.save(buffered, format=image_format)
160
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
161
+ return img_b64_str
162
+
163
+ def get_images(self, return_pil=False):
164
+ images = []
165
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
166
+ if i % 2 == 0:
167
+ if type(msg) is tuple:
168
+ msg, image, image_process_mode = msg
169
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
170
+ images.append(image)
171
+ return images
172
+
173
+ def to_gradio_chatbot(self):
174
+ ret = []
175
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
176
+ if i % 2 == 0:
177
+ if type(msg) is tuple:
178
+ msg, image, image_process_mode = msg
179
+ img_b64_str = self.process_image(
180
+ image, "Default", return_pil=False,
181
+ image_format='JPEG')
182
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
183
+ msg = img_str + msg.replace('<image>', '').strip()
184
+ ret.append([msg, None])
185
+ else:
186
+ ret.append([msg, None])
187
+ else:
188
+ if type(msg) is tuple and len(msg) == 2:
189
+ msg, img_b64_str = msg
190
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
191
+ msg = msg.strip() + img_str
192
+ ret[-1][-1] = msg
193
+ return ret
194
+
195
+ def copy(self):
196
+ return Conversation(
197
+ system=self.system,
198
+ roles=self.roles,
199
+ messages=[[x, y] for x, y in self.messages],
200
+ offset=self.offset,
201
+ sep_style=self.sep_style,
202
+ sep=self.sep,
203
+ sep2=self.sep2,
204
+ version=self.version)
205
+
206
+ def dict(self):
207
+ if len(self.get_images()) > 0:
208
+ return {
209
+ "system": self.system,
210
+ "roles": self.roles,
211
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
212
+ "offset": self.offset,
213
+ "sep": self.sep,
214
+ "sep2": self.sep2,
215
+ }
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": self.messages,
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ }
224
+
225
+
226
+ conv_vicuna_v0 = Conversation(
227
+ system="A chat between a curious human and an artificial intelligence assistant. "
228
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
229
+ roles=("Human", "Assistant"),
230
+ messages=(
231
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
232
+ ("Assistant",
233
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
234
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
235
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
236
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
237
+ "renewable and non-renewable energy sources:\n"
238
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
239
+ "energy sources are finite and will eventually run out.\n"
240
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
241
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
242
+ "and other negative effects.\n"
243
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
244
+ "have lower operational costs than non-renewable sources.\n"
245
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
246
+ "locations than non-renewable sources.\n"
247
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
248
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
249
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
250
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
251
+ ),
252
+ offset=2,
253
+ sep_style=SeparatorStyle.SINGLE,
254
+ sep="###",
255
+ )
256
+
257
+ conv_vicuna_v1 = Conversation(
258
+ system="A chat between a curious user and an artificial intelligence assistant. "
259
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
260
+ roles=("USER", "ASSISTANT"),
261
+ version="v1",
262
+ messages=(),
263
+ offset=0,
264
+ sep_style=SeparatorStyle.TWO,
265
+ sep=" ",
266
+ sep2="</s>",
267
+ )
268
+
269
+ conv_llama_2 = Conversation(
270
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
271
+
272
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
273
+ roles=("USER", "ASSISTANT"),
274
+ version="llama_v2",
275
+ messages=(),
276
+ offset=0,
277
+ sep_style=SeparatorStyle.LLAMA_2,
278
+ sep="<s>",
279
+ sep2="</s>",
280
+ )
281
+
282
+ conv_llava_llama_2 = Conversation(
283
+ system="You are a helpful language and vision assistant. "
284
+ "You are able to understand the visual content that the user provides, "
285
+ "and assist the user with a variety of tasks using natural language.",
286
+ roles=("USER", "ASSISTANT"),
287
+ version="llama_v2",
288
+ messages=(),
289
+ offset=0,
290
+ sep_style=SeparatorStyle.LLAMA_2,
291
+ sep="<s>",
292
+ sep2="</s>",
293
+ )
294
+
295
+ conv_mpt = Conversation(
296
+ system="""<|im_start|>system
297
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
298
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
299
+ version="mpt",
300
+ messages=(),
301
+ offset=0,
302
+ sep_style=SeparatorStyle.MPT,
303
+ sep="<|im_end|>",
304
+ )
305
+
306
+ conv_llava_plain = Conversation(
307
+ system="",
308
+ roles=("", ""),
309
+ messages=(
310
+ ),
311
+ offset=0,
312
+ sep_style=SeparatorStyle.PLAIN,
313
+ sep="\n",
314
+ )
315
+
316
+ conv_llava_v0 = Conversation(
317
+ system="A chat between a curious human and an artificial intelligence assistant. "
318
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
319
+ roles=("Human", "Assistant"),
320
+ messages=(
321
+ ),
322
+ offset=0,
323
+ sep_style=SeparatorStyle.SINGLE,
324
+ sep="###",
325
+ )
326
+
327
+ conv_llava_v0_mmtag = Conversation(
328
+ system="A chat between a curious user and an artificial intelligence assistant. "
329
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
330
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
331
+ roles=("Human", "Assistant"),
332
+ messages=(
333
+ ),
334
+ offset=0,
335
+ sep_style=SeparatorStyle.SINGLE,
336
+ sep="###",
337
+ version="v0_mmtag",
338
+ )
339
+
340
+ conv_llava_v1 = Conversation(
341
+ system="A chat between a curious human and an artificial intelligence assistant. "
342
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
343
+ roles=("USER", "ASSISTANT"),
344
+ version="v1",
345
+ messages=(),
346
+ offset=0,
347
+ sep_style=SeparatorStyle.TWO,
348
+ sep=" ",
349
+ sep2="</s>",
350
+ )
351
+
352
+ conv_vicuna_imgsp_v1 = Conversation(
353
+ system="A chat between a curious user and an artificial intelligence assistant. "
354
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
355
+ roles=("USER", "ASSISTANT"),
356
+ version="imgsp_v1",
357
+ messages=(),
358
+ offset=0,
359
+ sep_style=SeparatorStyle.TWO,
360
+ sep=" ",
361
+ sep2="</s>",
362
+ )
363
+
364
+ conv_llava_plain_guided = Conversation(
365
+ system="",
366
+ roles=("", ""),
367
+ version="plain_guided",
368
+ messages=(
369
+ ),
370
+ offset=0,
371
+ sep_style=SeparatorStyle.PLAIN,
372
+ sep="\n",
373
+ )
374
+
375
+ conv_llava_v1_mmtag = Conversation(
376
+ system="A chat between a curious user and an artificial intelligence assistant. "
377
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
378
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
379
+ roles=("USER", "ASSISTANT"),
380
+ messages=(),
381
+ offset=0,
382
+ sep_style=SeparatorStyle.TWO,
383
+ sep=" ",
384
+ sep2="</s>",
385
+ version="v1_mmtag",
386
+ )
387
+
388
+ conv_phi_2 = Conversation(
389
+ system="A chat between a curious user and an artificial intelligence assistant. "
390
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
391
+ roles=("USER", "ASSISTANT"),
392
+ version="phi2",
393
+ messages=(),
394
+ offset=0,
395
+ sep_style=SeparatorStyle.TWO,
396
+ sep=" ",
397
+ sep2="<|endoftext|>",
398
+ )
399
+
400
+ conv_mistral_instruct = Conversation(
401
+ system="",
402
+ roles=("USER", "ASSISTANT"),
403
+ version="llama_v2",
404
+ messages=(),
405
+ offset=0,
406
+ sep_style=SeparatorStyle.LLAMA_2,
407
+ sep="<s>",
408
+ sep2="</s>",
409
+ )
410
+
411
+ conv_gemma = Conversation(
412
+ system="",
413
+ roles=("user", "model"),
414
+ version="gemma",
415
+ messages=(),
416
+ offset=0,
417
+ sep_style=SeparatorStyle.GEMMA,
418
+ sep="",
419
+ sep2="<eos>",
420
+ )
421
+
422
+ conv_chatml_direct = Conversation(
423
+ system="""<|im_start|>system
424
+ Answer the questions.""",
425
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
426
+ version="mpt",
427
+ messages=(),
428
+ offset=0,
429
+ sep_style=SeparatorStyle.MPT,
430
+ sep="<|im_end|>",
431
+ )
432
+
433
+ default_conversation = conv_vicuna_v1
434
+ conv_templates = {
435
+ "default": conv_vicuna_v0,
436
+ "v0": conv_vicuna_v0,
437
+ "v1": conv_vicuna_v1,
438
+ "vicuna_v1": conv_vicuna_v1,
439
+ "phi_2": conv_phi_2,
440
+ "gemma": conv_gemma,
441
+ "llama_2": conv_llama_2,
442
+ "imgsp_v1": conv_vicuna_imgsp_v1,
443
+ "plain_guided": conv_llava_plain_guided,
444
+ "mistral_instruct": conv_mistral_instruct,
445
+ "chatml_direct": conv_chatml_direct,
446
+ "mistral_direct": conv_chatml_direct,
447
+ "plain": conv_llava_plain,
448
+ "v0_plain": conv_llava_plain,
449
+ "llava_v0": conv_llava_v0,
450
+ "v0_mmtag": conv_llava_v0_mmtag,
451
+ "llava_v1": conv_llava_v1,
452
+ "v1_mmtag": conv_llava_v1_mmtag,
453
+ "llava_llama_2": conv_llava_llama_2,
454
+
455
+ "mpt": conv_mpt,
456
+ }
457
+
458
+
459
+ if __name__ == "__main__":
460
+ print(default_conversation.get_prompt())
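Note: a quick usage sketch for these templates (the question string is made up; the flow matches CustomDataset.__getitem__ in i2t.py below):

# Copy a template, add one user turn and an empty assistant turn, then render the prompt.
from conversation import conv_templates

conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], "<image>\nWhat is shown in this photo?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# With SeparatorStyle.TWO this yields roughly:
# "<system prompt> USER: <image>\nWhat is shown in this photo? ASSISTANT:"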
i2t.py ADDED
@@ -0,0 +1,192 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import math
5
+ import torch
6
+ import argparse
7
+ import shortuuid
8
+ from tqdm import tqdm
9
+ from PIL import Image
10
+ from PIL import ImageFile
11
+ from torchvision import transforms
12
+
13
+ from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
14
+ from conversation import conv_templates, SeparatorStyle
15
+ from model.builder import load_pretrained_model
16
+ from tools import disable_torch_init
17
+ from mm_utils import tokenizer_image_token, get_model_name_from_path
18
+ from torch.utils.data import Dataset, DataLoader
19
+
20
+
21
+ from unitok.config import Args
22
+ from unitok.model import UniTok
23
+
24
+ ImageFile.LOAD_TRUNCATED_IMAGES = False
25
+ torch.set_grad_enabled(False)
26
+
27
+
28
+ def split_list(lst, n):
29
+ """Split a list into n (roughly) equal-sized chunks"""
30
+ chunk_size = math.ceil(len(lst) / n)  # ceiling division
31
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
32
+
33
+
34
+ def get_chunk(lst, n, k):
35
+ chunks = split_list(lst, n)
36
+ return chunks[k]
37
+
38
+
39
+ # Custom dataset class
40
+ class CustomDataset(Dataset):
41
+ def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
42
+ self.questions = questions
43
+ self.image_folder = image_folder
44
+ self.tokenizer = tokenizer
45
+ self.image_processor = image_processor
46
+ self.model_config = model_config
47
+
48
+ def __getitem__(self, index):
49
+ line = self.questions[index]
50
+ image_file = line["image"]
51
+ qs = line["text"]
52
+
53
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
54
+
55
+ conv = conv_templates[args.conv_mode].copy()
56
+ conv.append_message(conv.roles[0], qs)
57
+ conv.append_message(conv.roles[1], None)
58
+ prompt = conv.get_prompt()
59
+ # prompt = prompt.replace('<image>','<boi><image><eoi>')
60
+ # import pdb;pdb.set_trace()
61
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
62
+ # import pdb;pdb.set_trace()
63
+ pad_image = expand2square(image, (122, 116, 104) )
64
+ # import pdb;pdb.set_trace()
65
+ img = self.image_processor[0](pad_image).unsqueeze(0)
66
+ img = img.to('cuda')
67
+ # import pdb;pdb.set_trace()
68
+ with torch.no_grad():
69
+ vq_code = self.image_processor[1].img_to_idx(img)
70
+ vqcode = vq_code.cpu()
71
+
72
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
73
+
74
+
75
+ return input_ids,vqcode,os.path.join(self.image_folder, image_file) #, image_tensor, image_tensor_aux
76
+
77
+ def __len__(self):
78
+ return len(self.questions)
79
+
80
+
81
+ # DataLoader
82
+ def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=0):
83
+ assert batch_size == 1, "batch_size must be 1"
84
+ dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
85
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
86
+ return data_loader
87
+
88
+ def expand2square(pil_img, background_color):
89
+ width, height = pil_img.size
90
+ if width == height:
91
+ return pil_img
92
+ elif width > height:
93
+ result = Image.new(pil_img.mode, (width, width), background_color)
94
+ result.paste(pil_img, (0, (width - height) // 2))
95
+ return result
96
+ else:
97
+ result = Image.new(pil_img.mode, (height, height), background_color)
98
+ result.paste(pil_img, ((height - width) // 2, 0))
99
+ return result
100
+
101
+ def eval_model(args):
102
+ # Model
103
+ disable_torch_init()
104
+ model_path = os.path.expanduser(args.model_path)
105
+ model_name = get_model_name_from_path(model_path)
106
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, load_8bit=args.load_8bit)
107
+
108
+ ckpt = torch.load(args.tokenizer_path, map_location='cpu')
109
+ vae_cfg = Args()
110
+ vae_cfg.load_state_dict(ckpt['args'])
111
+ vq_model = UniTok(vae_cfg)
112
+ vq_model.load_state_dict(ckpt['trainer']['unitok'])
113
+ vq_model.to('cuda')
114
+ vq_model.eval()
115
+ del ckpt
116
+
117
+ crop_size = 256
118
+ transform = transforms.Compose([
119
+ transforms.Resize((crop_size, crop_size)),
120
+ transforms.ToTensor(),
121
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
122
+ ])
123
+ image_processor = (transform, vq_model)
124
+
125
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
126
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
127
+ answers_file = os.path.expanduser(args.answers_file)
128
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
129
+ ans_file = open(answers_file, "w")
130
+
131
+ if 'plain' in args.conv_mode and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
132
+ args.conv_mode = args.conv_mode + '_mmtag'
133
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
134
+
135
+ data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
136
+
137
+ for (input_ids, image_codes,imagepath), line in tqdm(zip(data_loader, questions), total=len(questions)):
138
+ idx = line["question_id"]
139
+ cur_prompt = line["text"]
140
+
141
+ input_ids = input_ids.to(device=model.device, non_blocking=True)
142
+ image_codes = image_codes.to(device=model.device, non_blocking=True)
143
+ if hasattr(model, "update_prompt"):
144
+ model.update_prompt([[cur_prompt]])
145
+
146
+ with torch.inference_mode():
147
+ output_ids = model.generate_mllm(
148
+ input_ids,
149
+ images=image_codes,
150
+ images_aux= None,
151
+ do_sample=True if args.temperature > 0 else False,
152
+ temperature=args.temperature,
153
+ top_p=args.top_p,
154
+ num_beams=args.num_beams,
155
+ max_new_tokens=args.max_new_tokens,
156
+ bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
157
+ eos_token_id=tokenizer.eos_token_id, # End of sequence token
158
+ pad_token_id=tokenizer.pad_token_id, # Pad token
159
+ use_cache=False
160
+ )
161
+
162
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
163
+ ans_id = shortuuid.uuid()
164
+ ans_file.write(json.dumps({
165
+ "question_id": idx,
166
+ "prompt": cur_prompt,
167
+ "text": outputs,
168
+ "answer_id": ans_id,
169
+ "model_id": model_name,
170
+ "metadata": {}
171
+ }) + "\n")
172
+ ans_file.close()
173
+
174
+ if __name__ == "__main__":
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
177
+ parser.add_argument("--tokenizer-path", type=str, required=True)
178
+ parser.add_argument("--model-base", type=str, default=None)
179
+ parser.add_argument("--image-folder", type=str, default="")
180
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
181
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
182
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
183
+ parser.add_argument("--num-chunks", type=int, default=1)
184
+ parser.add_argument("--chunk-idx", type=int, default=0)
185
+ parser.add_argument("--temperature", type=float, default=0.2)
186
+ parser.add_argument("--top_p", type=float, default=None)
187
+ parser.add_argument("--num_beams", type=int, default=1)
188
+ parser.add_argument('--load_8bit', type=bool, default=False)
189
+ parser.add_argument("--max_new_tokens", type=int, default=128)
190
+ args = parser.parse_args()
191
+
192
+ eval_model(args)
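Note: a hypothetical driver equivalent to the command-line entry point above (all paths are placeholders; the fields match the argparse definitions):

import argparse
import i2t

args = argparse.Namespace(
    model_path="/path/to/liquid-mllm",        # placeholder checkpoint directory
    tokenizer_path="/path/to/unitok.pth",     # placeholder UniTok checkpoint
    model_base=None,
    image_folder="/path/to/images",
    question_file="tables/question.jsonl",
    answers_file="answer.jsonl",
    conv_mode="llava_v1",
    num_chunks=1, chunk_idx=0,
    temperature=0.2, top_p=None, num_beams=1,
    load_8bit=False, max_new_tokens=128,
)
i2t.args = args        # CustomDataset.__getitem__ reads the module-level `args`
i2t.eval_model(args)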
mm_utils.py ADDED
@@ -0,0 +1,105 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from constants import IMAGE_TOKEN_INDEX
8
+
9
+
10
+ def load_image_from_base64(image):
11
+ return Image.open(BytesIO(base64.b64decode(image)))
12
+
13
+
14
+ def expand2square(pil_img, background_color):
15
+ width, height = pil_img.size
16
+ if width == height:
17
+ return pil_img
18
+ elif width > height:
19
+ result = Image.new(pil_img.mode, (width, width), background_color)
20
+ result.paste(pil_img, (0, (width - height) // 2))
21
+ return result
22
+ else:
23
+ result = Image.new(pil_img.mode, (height, height), background_color)
24
+ result.paste(pil_img, ((height - width) // 2, 0))
25
+ return result
26
+
27
+
28
+ def process_images(images, image_processor, model_cfg):
29
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30
+ new_images = []
31
+ if image_aspect_ratio == 'pad':
32
+ for image in images:
33
+ image = expand2square(image.convert('RGB'), tuple(int(x*255) for x in image_processor.image_mean))
34
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35
+ new_images.append(image)
36
+ else:
37
+ return image_processor(images, return_tensors='pt')['pixel_values']
38
+ if all(x.shape == new_images[0].shape for x in new_images):
39
+ new_images = torch.stack(new_images, dim=0)
40
+ return new_images
41
+
42
+
43
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
45
+
46
+ def insert_separator(X, sep):
47
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
48
+
49
+ input_ids = []
50
+ offset = 0
51
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
52
+ offset = 1
53
+ input_ids.append(prompt_chunks[0][0])
54
+
55
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
56
+ input_ids.extend(x[offset:])
57
+
58
+ if return_tensors is not None:
59
+ if return_tensors == 'pt':
60
+ return torch.tensor(input_ids, dtype=torch.long)
61
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
62
+ return input_ids
63
+
64
+
65
+ def get_model_name_from_path(model_path):
66
+ model_path = model_path.strip("/")
67
+ model_paths = model_path.split("/")
68
+ if model_paths[-1].startswith('checkpoint-'):
69
+ return model_paths[-2] + "_" + model_paths[-1]
70
+ else:
71
+ return model_paths[-1]
72
+
73
+ class KeywordsStoppingCriteria(StoppingCriteria):
74
+ def __init__(self, keywords, tokenizer, input_ids):
75
+ self.keywords = keywords
76
+ self.keyword_ids = []
77
+ self.max_keyword_len = 0
78
+ for keyword in keywords:
79
+ cur_keyword_ids = tokenizer(keyword).input_ids
80
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
81
+ cur_keyword_ids = cur_keyword_ids[1:]
82
+ if len(cur_keyword_ids) > self.max_keyword_len:
83
+ self.max_keyword_len = len(cur_keyword_ids)
84
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
85
+ self.tokenizer = tokenizer
86
+ self.start_len = input_ids.shape[1]
87
+
88
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
89
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
90
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
91
+ for keyword_id in self.keyword_ids:
92
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
93
+ if torch.equal(truncated_output_ids, keyword_id):
94
+ return True
95
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
96
+ for keyword in self.keywords:
97
+ if keyword in outputs:
98
+ return True
99
+ return False
100
+
101
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102
+ outputs = []
103
+ for i in range(output_ids.shape[0]):
104
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
105
+ return all(outputs)
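Note: a small sketch of what tokenizer_image_token does with the "<image>" placeholder (the tokenizer checkpoint is an assumption; only IMAGE_TOKEN_INDEX = -200 comes from constants.py):

# The prompt is split on "<image>", each chunk is tokenized, and the chunks are rejoined
# with the sentinel id -200 so the model can splice in image embeddings at that position.
from transformers import AutoTokenizer
from mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5", use_fast=False)  # placeholder tokenizer
ids = tokenizer_image_token("USER: <image>\nDescribe this image. ASSISTANT:", tokenizer)
assert -200 in ids   # IMAGE_TOKEN_INDEX marks where visual tokens will be inserted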
model/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .language_model.mini_gemini_llama import MiniGeminiLlamaForCausalLM
+ try:
+     from .language_model.mini_gemini_mistral import MiniGeminiMistralForCausalLM
+     from .language_model.mini_gemini_mixtral import MiniGeminiMixtralForCausalLM
+     from .language_model.mini_gemini_gemma import MiniGeminiGemmaForCausalLM
+ except:
+     ImportWarning("New model not imported. Try to update Transformers.")
model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (714 Bytes).
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (585 Bytes).
model/__pycache__/arhead.cpython-311.pyc ADDED
Binary file (11.1 kB).
model/__pycache__/arhead.cpython-39.pyc ADDED
Binary file (5.89 kB).
model/__pycache__/builder.cpython-311.pyc ADDED
Binary file (5.87 kB).
model/__pycache__/builder.cpython-39.pyc ADDED
Binary file (2.91 kB).
model/__pycache__/liquid.cpython-311.pyc ADDED
Binary file (43.8 kB).
model/__pycache__/mini_gemini_arch.cpython-311.pyc ADDED
Binary file (49.9 kB).
model/__pycache__/mini_gemini_arch.cpython-39.pyc ADDED
Binary file (20.5 kB).
model/__pycache__/quant.cpython-311.pyc ADDED
Binary file (38.9 kB).
model/__pycache__/quant.cpython-39.pyc ADDED
Binary file (16.5 kB).
model/arhead.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.utils.checkpoint
3
+ from torch import nn
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
7
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaDecoderLayer
8
+ from transformers.modeling_outputs import BaseModelOutputWithPast
9
+
10
+ class AR_head(nn.Module):
11
+ """
12
+ Lightweight autoregressive head: a stack of 3 [`LlamaDecoderLayer`] blocks used to predict multi-codebook image tokens.
13
+
14
+ Args:
15
+ config: LlamaConfig
16
+ """
17
+
18
+ def __init__(self, config, codebook_size, num_codebooks):
19
+ super().__init__()
20
+ # import pdb;pdb.set_trace()
21
+ self.num_codebooks = num_codebooks
22
+ vocab_size = codebook_size
23
+ self.sub_vocab_size = vocab_size // self.num_codebooks
24
+
25
+ # self.layers = nn.ModuleList(
26
+ # [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(3)]
27
+ # )
28
+ # self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
29
+ self.linear_head = nn.Linear(config.hidden_size, self.sub_vocab_size)
30
+
31
+ self.layers = nn.ModuleList(
32
+ [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(3)]
33
+ )
34
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
35
+ self.gradient_checkpointing = False
36
+
37
+
38
+
39
+
40
+ # vocab_size 16384
41
+ self.codebooks = nn.ModuleList()
42
+ for _ in range(self.num_codebooks-1):
43
+ codebook = nn.Embedding(self.sub_vocab_size, config.hidden_size)
44
+ self.codebooks.append(codebook)
45
+ # import pdb;pdb.set_trace()
46
+ self.config = config
47
+ self.gradient_checkpointing = False
48
+
49
+ # Initialize weights and apply final processing
50
+ self._init_weights(self.layers)
51
+
52
+ def set_input_embeddings(self, value):
53
+ self.embed_tokens = value
54
+
55
+ def _init_weights(self, module):
56
+ std = self.config.initializer_range
57
+ if isinstance(module, nn.Linear):
58
+ module.weight.data.normal_(mean=0.0, std=std)
59
+ if module.bias is not None:
60
+ module.bias.data.zero_()
61
+ elif isinstance(module, nn.Embedding):
62
+ module.weight.data.normal_(mean=0.0, std=std)
63
+ if module.padding_idx is not None:
64
+ module.weight.data[module.padding_idx].zero_()
65
+
66
+ # Ignore copy
67
+ def forward(
68
+ self,
69
+ input_ids: torch.LongTensor = None,
70
+ attention_mask: Optional[torch.Tensor] = None,
71
+ position_ids: Optional[torch.LongTensor] = None,
72
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
73
+ inputs_embeds: Optional[torch.FloatTensor] = None,
74
+ use_cache: Optional[bool] = None,
75
+ output_attentions: Optional[bool] = None,
76
+ output_hidden_states: Optional[bool] = None,
77
+ return_dict: Optional[bool] = None,
78
+ cache_position: Optional[torch.LongTensor] = None,
79
+ ) -> torch.tensor:
80
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
81
+ output_hidden_states = (
82
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
83
+ )
84
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
85
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
86
+
87
+ if (input_ids is None) ^ (inputs_embeds is not None):
88
+ raise ValueError(
89
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
90
+ )
91
+
92
+ if self.gradient_checkpointing and self.training and use_cache:
93
+ logger.warning_once(
94
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
95
+ )
96
+ use_cache = False
97
+
98
+ if inputs_embeds is None:
99
+ inputs_embeds = self.embed_tokens(input_ids)
100
+
101
+ past_seen_tokens = 0
102
+ if use_cache: # kept for BC (cache positions)
103
+ if not isinstance(past_key_values, StaticCache):
104
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
105
+ past_seen_tokens = past_key_values.get_seq_length()
106
+
107
+ if cache_position is None:
108
+ if isinstance(past_key_values, StaticCache):
109
+ raise ValueError("cache_position is a required argument when using StaticCache.")
110
+ cache_position = torch.arange(
111
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
112
+ )
113
+
114
+ if position_ids is None:
115
+ position_ids = cache_position.unsqueeze(0)
116
+
117
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
118
+
119
+ # embed positions
120
+ hidden_states = inputs_embeds
121
+
122
+ # decoder layers
123
+ all_hidden_states = () if output_hidden_states else None
124
+ all_self_attns = () if output_attentions else None
125
+ next_decoder_cache = None
126
+
127
+ for decoder_layer in self.layers:
128
+ if output_hidden_states:
129
+ all_hidden_states += (hidden_states,)
130
+
131
+ if self.gradient_checkpointing and self.training:
132
+ layer_outputs = self._gradient_checkpointing_func(
133
+ decoder_layer.__call__,
134
+ hidden_states,
135
+ causal_mask,
136
+ position_ids,
137
+ past_key_values,
138
+ output_attentions,
139
+ use_cache,
140
+ cache_position,
141
+ )
142
+ else:
143
+ layer_outputs = decoder_layer(
144
+ hidden_states,
145
+ attention_mask=causal_mask,
146
+ position_ids=position_ids,
147
+ past_key_value=past_key_values,
148
+ output_attentions=output_attentions,
149
+ use_cache=use_cache,
150
+ cache_position=cache_position,
151
+ )
152
+
153
+ hidden_states = layer_outputs[0]
154
+
155
+ if use_cache:
156
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
157
+
158
+ if output_attentions:
159
+ all_self_attns += (layer_outputs[1],)
160
+
161
+ hidden_states = self.norm(hidden_states)
162
+
163
+ if output_hidden_states:
164
+ all_hidden_states += (hidden_states,)
165
+
166
+ next_cache = None
167
+ if use_cache:
168
+ next_cache = (
169
+ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
170
+ )
171
+ if not return_dict:
172
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
173
+ return BaseModelOutputWithPast(
174
+ last_hidden_state=hidden_states,
175
+ past_key_values=next_cache,
176
+ hidden_states=all_hidden_states,
177
+ attentions=all_self_attns,
178
+ )
179
+
180
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
181
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
182
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
183
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
184
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
185
+ if self.config._attn_implementation == "flash_attention_2":
186
+ if attention_mask is not None and 0.0 in attention_mask:
187
+ return attention_mask
188
+ return None
189
+
190
+ dtype, device = input_tensor.dtype, input_tensor.device
191
+ min_dtype = torch.finfo(dtype).min
192
+ sequence_length = input_tensor.shape[1]
193
+ if hasattr(self.layers[0].self_attn, "past_key_value"): # static cache
194
+ target_length = self.config.max_position_embeddings
195
+ else: # dynamic cache
196
+ target_length = (
197
+ attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
198
+ )
199
+
200
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
201
+ if sequence_length != 1:
202
+ causal_mask = torch.triu(causal_mask, diagonal=1)
203
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
204
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
205
+ if attention_mask is not None:
206
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
207
+ if attention_mask.dim() == 2:
208
+ mask_length = attention_mask.shape[-1]
209
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
210
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
211
+ elif attention_mask.dim() == 4:
212
+ # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
213
+ # cache. In that case, the 4D attention mask attends to the newest tokens only.
214
+ if attention_mask.shape[-2] < cache_position[0] + sequence_length:
215
+ offset = cache_position[0]
216
+ else:
217
+ offset = 0
218
+ mask_shape = attention_mask.shape
219
+ mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
220
+ causal_mask[
221
+ : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
222
+ ] = mask_slice
223
+
224
+ if (
225
+ self.config._attn_implementation == "sdpa"
226
+ and attention_mask is not None
227
+ and attention_mask.device.type == "cuda"
228
+ ):
229
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
230
+ is_tracing = (
231
+ torch.jit.is_tracing()
232
+ or isinstance(input_tensor, torch.fx.Proxy)
233
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
234
+ )
235
+ if not is_tracing and torch.any(attention_mask != 1):
236
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
237
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
238
+ # Details: https://github.com/pytorch/pytorch/issues/110213
239
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
240
+
241
+ return causal_mask
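Note: AR_head splits the image-token vocabulary into num_codebooks sub-vocabularies of codebook_size // num_codebooks entries each. A minimal instantiation sketch (the LlamaConfig values are illustrative; the 32768/8 split matches the call in mini_gemini_llama.py below):

from transformers import LlamaConfig
from model.arhead import AR_head

cfg = LlamaConfig(hidden_size=4096, num_hidden_layers=32, num_attention_heads=32)
head = AR_head(cfg, codebook_size=32768, num_codebooks=8)
print(head.sub_vocab_size)   # 32768 // 8 = 4096 entries per sub-codebook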
model/builder.py ADDED
@@ -0,0 +1,138 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ import os
20
+ import torch
21
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
22
+
23
+ from model import *
24
+ from constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
+
29
+ if device != "cuda":
30
+ kwargs['device_map'] = {"": device}
31
+
32
+ if load_8bit:
33
+ kwargs['load_in_8bit'] = True
34
+ elif load_4bit:
35
+ kwargs['load_in_4bit'] = True
36
+ kwargs['quantization_config'] = BitsAndBytesConfig(
37
+ load_in_4bit=True,
38
+ bnb_4bit_compute_dtype=torch.float16,
39
+ bnb_4bit_use_double_quant=True,
40
+ bnb_4bit_quant_type='nf4'
41
+ )
42
+ else:
43
+ kwargs['torch_dtype'] = torch.float16
44
+
45
+ if use_flash_attn:
46
+ kwargs['attn_implementation'] = 'flash_attention_2'
47
+ # import pdb;pdb.set_trace()
48
+ if 'mini-gemini' in model_name.lower():
49
+ # Load MiniGemini model
50
+ if model_base is not None:
51
+ # this may be mm projector only
52
+ print('Loading MiniGemini from base model...')
53
+
54
+ if "8x7b" in model_name.lower():
55
+ tokenizer = AutoTokenizer.from_pretrained(model_base)
56
+ model = MiniGeminiMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
57
+ elif "2b" in model_name.lower():
58
+ tokenizer = AutoTokenizer.from_pretrained(model_base)
59
+ model = MiniGeminiGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
60
+ else:
61
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
62
+ model = MiniGeminiLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
63
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
64
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
65
+ model.load_state_dict(mm_projector_weights, strict=False)
66
+ else:
67
+ if "8x7b" in model_name.lower():
68
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
69
+ model = MiniGeminiMixtralForCausalLM.from_pretrained(model_path, **kwargs)
70
+ elif "2b" in model_name.lower():
71
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
72
+ model = MiniGeminiGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
73
+ else:
74
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
75
+ model = MiniGeminiLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
76
+
77
+ if 'gemma' in model_name.lower():
78
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
79
+ model = MiniGeminiGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
80
+ elif 'vicuna' in model_name.lower() or 'unitok' in model_name.lower():
81
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
82
+ model = MiniGeminiLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
83
+ else:
84
+ # Load language model
85
+ if model_base is not None:
86
+ # PEFT model
87
+ from peft import PeftModel
88
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
89
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
90
+ print(f"Loading LoRA weights from {model_path}")
91
+ model = PeftModel.from_pretrained(model, model_path)
92
+ print(f"Merging weights")
93
+ model = model.merge_and_unload()
94
+ print('Convert to FP16...')
95
+ model.to(torch.float16)
96
+ else:
97
+ if 'mpt' in model_name.lower():
98
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
99
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
100
+ else:
101
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
102
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
103
+
104
+ image_processor = None
105
+ # import pdb;pdb.set_trace()
106
+
107
+
108
+ # mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
109
+ # mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
110
+ # if mm_use_im_patch_token:
111
+ # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
112
+ # if mm_use_im_start_end:
113
+ # tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
114
+
115
+ # model.resize_token_embeddings(len(tokenizer))
116
+
117
+ # vision_tower = model.get_vision_tower()
118
+ # if not vision_tower.is_loaded:
119
+ # vision_tower.load_model()
120
+ # vision_tower.to(device=device, dtype=torch.float16)
121
+ # image_processor = vision_tower.image_processor
122
+
123
+ if 'mini-gemini' in model_name.lower():
124
+ vision_tower_aux = model.get_vision_tower_aux()
125
+ if not vision_tower_aux.is_loaded:
126
+ vision_tower_aux.load_model()
127
+ vision_tower_aux.to(device=device, dtype=torch.float16)
128
+
129
+ # initialize attention modules
130
+ model.config.model_path = model_path
131
+ model.get_model().initialize_uni_modules(model.config, for_eval=True)
132
+
133
+ if hasattr(model.config, "max_sequence_length"):
134
+ context_len = model.config.max_sequence_length
135
+ else:
136
+ context_len = 2048
137
+
138
+ return tokenizer, model, image_processor, context_len
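Note: as used in i2t.py above, loading looks roughly like this (the checkpoint path is a placeholder; a 'unitok' or 'vicuna' model name routes to MiniGeminiLlamaForCausalLM):

from model.builder import load_pretrained_model
from mm_utils import get_model_name_from_path

model_path = "/path/to/Liquid-unitok-7b"       # placeholder
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=model_name, load_8bit=False)
# image_processor is returned as None here; i2t.py builds its own (transform, UniTok) pair.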
model/language_model/__pycache__/mini_gemini_llama.cpython-311.pyc ADDED
Binary file (18 kB).
model/language_model/__pycache__/mini_gemini_llama.cpython-39.pyc ADDED
Binary file (7.73 kB).
model/language_model/mini_gemini_llama.py ADDED
@@ -0,0 +1,488 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.nn import CrossEntropyLoss
23
+ from typing import List, Optional, Tuple, Union
24
+
25
+ from transformers.utils import logging
26
+ from transformers.generation.utils import GenerateOutput
27
+ from transformers.modeling_outputs import CausalLMOutputWithPast
28
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
29
+
30
+ from model.arhead import AR_head
31
+ from model.liquid import MiniGeminiMetaModel, MiniGeminiMetaForCausalLM
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ class MiniGeminiConfig(LlamaConfig):
38
+ model_type = "mini_gemini"
39
+
40
+
41
+ class MiniGeminiLlamaModel(MiniGeminiMetaModel, LlamaModel):
42
+ config_class = MiniGeminiConfig
43
+
44
+ def __init__(self, config: LlamaConfig):
45
+ super(MiniGeminiLlamaModel, self).__init__(config)
46
+
47
+
48
+ class MiniGeminiLlamaForCausalLM(LlamaForCausalLM, MiniGeminiMetaForCausalLM):
49
+ config_class = MiniGeminiConfig
50
+
51
+ def __init__(self, config):
52
+ super(LlamaForCausalLM, self).__init__(config)
53
+ self.model = MiniGeminiLlamaModel(config)
54
+ self.pretraining_tp = config.pretraining_tp
55
+ self.vocab_size = config.vocab_size
56
+
57
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
58
+ self.ar_head = AR_head(self.config, codebook_size=32768, num_codebooks=8)
59
+
60
+ # Initialize weights and apply final processing
61
+ self.post_init()
62
+
63
+ def get_model(self):
64
+ return self.model
65
+
66
+ def forward(
67
+ self,
68
+ input_ids: torch.LongTensor = None,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ position_ids: Optional[torch.LongTensor] = None,
71
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
72
+ inputs_embeds: Optional[torch.FloatTensor] = None,
73
+ labels: Optional[torch.LongTensor] = None,
74
+ data_types: torch.LongTensor = None,
75
+ use_cache: Optional[bool] = None,
76
+ cache_position: Optional[torch.LongTensor] = None,
77
+ output_attentions: Optional[bool] = None,
78
+ output_hidden_states: Optional[bool] = None,
79
+ images: Optional[torch.FloatTensor] = None,
80
+ images_aux: Optional[torch.FloatTensor] = None,
81
+ return_dict: Optional[bool] = None,
82
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
83
+
84
+
85
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
86
+ output_hidden_states = (
87
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
88
+ )
89
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
90
+
91
+ additional_image_indexs = None
92
+ if inputs_embeds is None and past_key_values is None: # skipped during inference, where inputs_embeds is precomputed
93
+ (
94
+ input_ids,
95
+ position_ids,
96
+ attention_mask,
97
+ past_key_values,
98
+ inputs_embeds,
99
+ labels,
100
+ data_types,
101
+ additional_image_labels,
102
+ additional_image_indexs
103
+ ) = self.prepare_inputs_labels_for_multimodal(
104
+ input_ids,
105
+ position_ids,
106
+ attention_mask,
107
+ past_key_values,
108
+ labels,
109
+ images,
110
+ images_aux,
111
+ data_types
112
+ )
113
+
114
+ outputs = self.model(
115
+ input_ids=input_ids,
116
+ attention_mask=attention_mask,
117
+ position_ids=position_ids,
118
+ past_key_values=past_key_values,
119
+ inputs_embeds=inputs_embeds,
120
+ use_cache=use_cache,
121
+ output_attentions=output_attentions,
122
+ output_hidden_states=output_hidden_states,
123
+ return_dict=return_dict,
124
+ )
125
+
126
+ hidden_states = outputs[0]
127
+
128
+ if self.pretraining_tp > 1:
129
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
130
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
131
+ logits = torch.cat(logits, dim=-1)
132
+ else:
133
+ logits = self.lm_head(hidden_states)
134
+ logits = logits.float()
135
+
136
+ text_loss = None
137
+ if labels is not None:
138
+ # Shift so that tokens < n predict n
139
+ shift_logits = logits[..., :-1, :].contiguous()
140
+ shift_labels = labels[..., 1:].contiguous()
141
+ # Flatten the tokens
142
+ loss_fct = CrossEntropyLoss()
143
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
144
+ shift_labels = shift_labels.view(-1)
145
+ # Enable model parallelism
146
+ shift_labels = shift_labels.to(shift_logits.device)
147
+ text_loss = loss_fct(shift_logits, shift_labels)
148
+ num_text_tokens = (shift_labels != -100).sum().item()
149
+
150
+ if additional_image_indexs is None:
151
+ return CausalLMOutputWithPast(
152
+ loss=text_loss,
153
+ logits=logits,
154
+ past_key_values=outputs.past_key_values,
155
+ hidden_states=outputs.hidden_states,
156
+ attentions=outputs.attentions,
157
+ )
158
+
159
+ to_image_mask = data_types == 1 # which samples in the batch take the t2i loss, e.g. [True, False, False, True, ...]
160
+
161
+ if len(additional_image_indexs) > 0 and len(to_image_mask) == len(hidden_states): # image generation loss
162
+ to_image_states = hidden_states[to_image_mask]
163
+
164
+ # assert len(to_image_states) == len(additional_image_indexs)
165
+ if len(to_image_states) != len(additional_image_indexs):
166
+ print('to_image_mask', to_image_mask)
167
+ print('additional_image_indexs', additional_image_indexs)
168
+ shift_image_states = torch.stack([state[start_id - 1:end_id - 1] for (start_id, end_id), state in
169
+ zip(additional_image_indexs, to_image_states)]) # Shift so that tokens < n predict n [bz, seq_len, hidden_dim]
170
+ base_tokens = shift_image_states
171
+
172
+ K = self.ar_head.num_codebooks
173
+ B, L, C = base_tokens.shape
174
+ base_tokens = base_tokens.reshape(B * L, 1, C)
175
+
176
+ targets = torch.cat(additional_image_labels, dim=0) # [B, K, L]
177
+ image_code_labels = targets
178
+ targets = targets.permute(0, 2, 1).reshape(B * L, K)[:, :-1]
179
+ index_embeddings = []
180
+ for i in range(K - 1):
181
+ index_embed = self.ar_head.codebooks[i](targets[:, i])
182
+ index_embeddings.append(index_embed)
183
+ index_embeddings = torch.stack(index_embeddings, dim=1)
184
+ # import pdb;pdb.set_trace()
185
+ h = torch.cat((base_tokens, index_embeddings), dim=1) # [B*L, K, C]
186
+
187
+ multicode_embedding = self.ar_head(
188
+ input_ids=None,
189
+ attention_mask=None,
190
+ position_ids=None,
191
+ past_key_values=None,
192
+ inputs_embeds=h,
193
+ use_cache=False,
194
+ output_attentions=False,
195
+ output_hidden_states=False,
196
+ return_dict=False,
197
+ cache_position=None,
198
+ )
199
+ image_logits = self.ar_head.linear_head(multicode_embedding)
200
+ image_logits = image_logits.reshape(B, L, K, -1).permute(0, 2, 1, 3) # [B, K, L, sub_vocab_size]
201
+ loss_fct = CrossEntropyLoss()
202
+ image_logits = image_logits.reshape(-1, self.ar_head.sub_vocab_size)
203
+ image_labels = image_code_labels.view(-1)
204
+ image_labels = image_labels.to(image_logits.device)
205
+ image_softmax_normalizer = image_logits.max(-1).values ** 2
206
+ image_z_loss = 0.00005 * image_softmax_normalizer.mean()
207
+ image_loss = loss_fct(image_logits, image_labels) + image_z_loss
208
+ num_image_tokens = image_labels.shape[0]
209
+ else:
210
+ if len(hidden_states) != len(to_image_mask):
211
+ print('to_image_mask', to_image_mask)
212
+ print('hidden_states', hidden_states.shape)
213
+ print('inputs_embeds', inputs_embeds.shape)
214
+ print('additional_image_indexs', additional_image_indexs)
215
+ fake_ids = torch.ones(1, self.model.multi_embedder.num_codebooks - 1).to(inputs_embeds).long()
216
+ index_embeddings = []
217
+ for i in range(self.model.multi_embedder.num_codebooks - 1):
218
+ index_embed = self.ar_head.codebooks[i](fake_ids[:, i])
219
+ index_embeddings.append(index_embed)
220
+ index_embeddings = torch.stack(index_embeddings, dim=1)
221
+
222
+ multicode_embedding = self.ar_head(
223
+ input_ids=None,
224
+ attention_mask=None,
225
+ position_ids=None,
226
+ past_key_values=None,
227
+ inputs_embeds=index_embeddings,
228
+ use_cache=False,
229
+ output_attentions=False,
230
+ output_hidden_states=False,
231
+ return_dict=False,
232
+ cache_position=None,
233
+ )
234
+ image_logits = self.ar_head.linear_head(multicode_embedding)
235
+
236
+ num_image_tokens = 0
237
+ image_loss = (image_logits * 0).sum() # + (base_tokens*0).sum()
238
+ pass
239
+
240
+ loss = image_loss * (num_image_tokens / (num_image_tokens + num_text_tokens)) + \
241
+ text_loss * (num_text_tokens / (num_image_tokens + num_text_tokens))
242
+
243
+ # t2i_ratio = to_image_mask.sum() / len(to_image_mask)
244
+ # loss = image_loss * t2i_ratio + text_loss * (1 - t2i_ratio)
245
+
246
+ if not return_dict:
247
+ output = (logits,) + outputs[1:]
248
+ return (loss,) + output if loss is not None else output
249
+
250
+ return CausalLMOutputWithPast(
251
+ loss=loss,
252
+ logits=logits,
253
+ past_key_values=outputs.past_key_values,
254
+ hidden_states=outputs.hidden_states,
255
+ attentions=outputs.attentions,
256
+ )
257
+
258
+ @torch.no_grad()
259
+ def generate_mllm(
260
+ self,
261
+ inputs: Optional[torch.Tensor] = None,
262
+ images: Optional[torch.Tensor] = None,
263
+ images_aux: Optional[torch.FloatTensor] = None,
264
+ **kwargs,
265
+ ) -> Union[GenerateOutput, torch.LongTensor]:
266
+ position_ids = kwargs.pop("position_ids", None)
267
+ attention_mask = kwargs.pop("attention_mask", None)
268
+ if "inputs_embeds" in kwargs:
269
+ raise NotImplementedError("`inputs_embeds` is not supported")
270
+ # import pdb;pdb.set_trace()
271
+ if images is not None:
272
+ (
273
+ inputs,
274
+ position_ids,
275
+ attention_mask,
276
+ _,
277
+ inputs_embeds,
278
+ _
279
+ ) = self.prepare_inputs_for_multimodal(
280
+ inputs,
281
+ position_ids,
282
+ attention_mask,
283
+ None,
284
+ None,
285
+ images,
286
+ images_aux
287
+ )
288
+ else:
289
+ inputs_embeds = self.get_model().embed_tokens(inputs)
290
+ # import pdb;pdb.set_trace()
291
+ return super().generate(
292
+ position_ids=position_ids,
293
+ attention_mask=attention_mask,
294
+ inputs_embeds=inputs_embeds,
295
+ **kwargs
296
+ )
297
+
298
+ @torch.no_grad()
299
+ def generate(
300
+ self,
301
+ inputs: Optional[torch.Tensor] = None,
302
+ images: Optional[torch.Tensor] = None,
303
+ images_aux: Optional[torch.FloatTensor] = None,
304
+ **kwargs,
305
+ ) -> Union[GenerateOutput, torch.LongTensor]:
306
+ position_ids = kwargs.pop("position_ids", None)
307
+ attention_mask = kwargs.pop("attention_mask", None)
308
+ if "inputs_embeds" in kwargs:
309
+ raise NotImplementedError("`inputs_embeds` is not supported")
310
+
311
+ if images is not None:
312
+ (
313
+ inputs,
314
+ position_ids,
315
+ attention_mask,
316
+ _,
317
+ inputs_embeds,
318
+ _
319
+ ) = self.prepare_inputs_for_multimodal(
320
+ inputs,
321
+ position_ids,
322
+ attention_mask,
323
+ None,
324
+ None,
325
+ images,
326
+ images_aux
327
+ )
328
+ else:
329
+ inputs_embeds = self.get_model().embed_tokens(inputs)
330
+
331
+ return super().generate(
332
+ position_ids=position_ids,
333
+ attention_mask=attention_mask,
334
+ inputs_embeds=inputs_embeds,
335
+ **kwargs
336
+ )
337
+
338
+ def test_forward(
339
+ self,
340
+ input_ids: torch.LongTensor = None,
341
+ attention_mask: Optional[torch.Tensor] = None,
342
+ position_ids: Optional[torch.LongTensor] = None,
343
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
344
+ inputs_embeds: Optional[torch.FloatTensor] = None,
345
+ labels: Optional[torch.LongTensor] = None,
346
+ input_multi_ids: torch.LongTensor = None,
347
+ data_types: torch.LongTensor = None,
348
+ use_cache: Optional[bool] = None,
349
+ cache_position: Optional[torch.LongTensor] = None,
350
+ output_attentions: Optional[bool] = None,
351
+ output_hidden_states: Optional[bool] = None,
352
+ images: Optional[torch.FloatTensor] = None,
353
+ images_aux: Optional[torch.FloatTensor] = None,
354
+ return_dict: Optional[bool] = None,
355
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
356
+ # import pdb;pdb.set_trace()
357
+ if input_multi_ids is not None:
358
+ input_multi_ids = input_multi_ids.unsqueeze(-1) # [B,K,1]
359
+ input_ids = None # [B,1]
360
+ inputs_embeds = self.model.multi_embedder(input_multi_ids) # [B,1,C]
361
+
362
+ outputs = self.model(
363
+ input_ids=input_ids,
364
+ attention_mask=attention_mask,
365
+ position_ids=position_ids,
366
+ past_key_values=past_key_values,
367
+ inputs_embeds=inputs_embeds,
368
+ use_cache=use_cache,
369
+ output_attentions=output_attentions,
370
+ output_hidden_states=output_hidden_states,
371
+ return_dict=return_dict,
372
+ )
373
+ return outputs
374
+
375
+ def T2I_forward_nocache(
376
+ self,
377
+ input_ids: torch.LongTensor = None,
378
+ attention_mask: Optional[torch.Tensor] = None,
379
+ position_ids: Optional[torch.LongTensor] = None,
380
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
381
+ inputs_embeds: Optional[torch.FloatTensor] = None,
382
+ labels: Optional[torch.LongTensor] = None,
383
+ input_multi_ids: torch.LongTensor = None,
384
+ data_types: torch.LongTensor = None,
385
+ use_cache: Optional[bool] = None,
386
+ cache_position: Optional[torch.LongTensor] = None,
387
+ output_attentions: Optional[bool] = None,
388
+ output_hidden_states: Optional[bool] = None,
389
+ images: Optional[torch.FloatTensor] = None,
390
+ images_aux: Optional[torch.FloatTensor] = None,
391
+ return_dict: Optional[bool] = None,
392
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
393
+ # import pdb;pdb.set_trace()
394
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
395
+ output_hidden_states = (
396
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
397
+ )
398
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
399
+
400
+ if input_multi_ids is not None:
401
+ inputs_text_embeds = self.get_model().embed_tokens(input_ids)
402
+ input_ids = None # [B,1]
403
+ inputs_image_embeds = self.model.multi_embedder(input_multi_ids) # [B,1,C]
404
+ inputs_image_mask = torch.empty(inputs_image_embeds.shape[0], inputs_image_embeds.shape[1]).fill_(1).to(
405
+ attention_mask)
406
+ inputs_embeds = torch.cat([inputs_text_embeds, inputs_image_embeds], dim=1)
407
+ attention_mask = torch.cat([attention_mask, inputs_image_mask], dim=1)
408
+ position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).repeat(
409
+ inputs_embeds.shape[0], 1)
410
+ else:
411
+ inputs_embeds = self.get_model().embed_tokens(input_ids)
412
+ input_ids = None
413
+
414
+ outputs = self.model(
415
+ input_ids=input_ids,
416
+ attention_mask=attention_mask,
417
+ position_ids=position_ids,
418
+ past_key_values=past_key_values,
419
+ inputs_embeds=inputs_embeds,
420
+ use_cache=use_cache,
421
+ output_attentions=output_attentions,
422
+ output_hidden_states=output_hidden_states,
423
+ return_dict=return_dict,
424
+ )
425
+ return outputs
426
+
427
+ def T2I_forward_withcache(
428
+ self,
429
+ input_ids: torch.LongTensor = None,
430
+ attention_mask: Optional[torch.Tensor] = None,
431
+ position_ids: Optional[torch.LongTensor] = None,
432
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
433
+ inputs_embeds: Optional[torch.FloatTensor] = None,
434
+ labels: Optional[torch.LongTensor] = None,
435
+ input_multi_ids: torch.LongTensor = None,
436
+ data_types: torch.LongTensor = None,
437
+ use_cache: Optional[bool] = None,
438
+ cache_position: Optional[torch.LongTensor] = None,
439
+ output_attentions: Optional[bool] = None,
440
+ output_hidden_states: Optional[bool] = None,
441
+ images: Optional[torch.FloatTensor] = None,
442
+ images_aux: Optional[torch.FloatTensor] = None,
443
+ return_dict: Optional[bool] = None,
444
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
445
+ # import pdb;pdb.set_trace()
446
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
447
+ output_hidden_states = (
448
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
449
+ )
450
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
451
+
452
+ if input_multi_ids is not None:
453
+ inputs_image_embeds = self.model.multi_embedder(input_multi_ids[:, :, -1:]) # [B,1,C]
454
+ inputs_embeds = inputs_image_embeds
455
+ input_ids = None # [B,1]
456
+ else:
457
+ inputs_embeds = self.get_model().embed_tokens(input_ids)
458
+ input_ids = None
459
+
460
+ outputs = self.model(
461
+ input_ids=input_ids,
462
+ attention_mask=attention_mask,
463
+ position_ids=position_ids,
464
+ past_key_values=past_key_values,
465
+ inputs_embeds=inputs_embeds,
466
+ use_cache=use_cache,
467
+ output_attentions=output_attentions,
468
+ output_hidden_states=output_hidden_states,
469
+ return_dict=return_dict,
470
+ )
471
+ return outputs
472
+
473
+
474
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
475
+ images = kwargs.pop("images", None)
476
+ images_aux = kwargs.pop("images_aux", None)
477
+ _inputs = super().prepare_inputs_for_generation(
478
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
479
+ )
480
+ if images is not None:
481
+ _inputs['images'] = images
482
+ if images_aux is not None:
483
+ _inputs['images_aux'] = images_aux
484
+ return _inputs
485
+
486
+
487
+ AutoConfig.register("mini_gemini", MiniGeminiConfig)
488
+ AutoModelForCausalLM.register(MiniGeminiConfig, MiniGeminiLlamaForCausalLM)
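The two register() calls above hook the model into the Transformers Auto classes. A minimal sketch of what that enables, assuming a converted checkpoint directory whose config.json declares model_type "mini_gemini" (the path below is a placeholder):
from transformers import AutoConfig, AutoModelForCausalLM
import model.language_model.mini_gemini_llama  # executes the register() calls above

config = AutoConfig.from_pretrained("path/to/mini_gemini-checkpoint")  # placeholder path
assert config.model_type == "mini_gemini"
llm = AutoModelForCausalLM.from_pretrained("path/to/mini_gemini-checkpoint")
# llm is a MiniGeminiLlamaForCausalLM: a Llama backbone with lm_head for text tokens
# plus the 8-codebook AR_head used for image-token prediction.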
model/liquid.py ADDED
@@ -0,0 +1,669 @@
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ------------------------------------------------------------------------
15
+ # Modified from LLaVA (https://github.com/haotian-liu/LLaVA)
16
+ # Copyright 2024 Yanwei Li
17
+ # ------------------------------------------------------------------------
18
+ # Modified from MiniGemini (https://github.com/dvlab-research/MGM)
19
+ # Copyright 2025 ByteDance
20
+ # ------------------------------------------------------------------------
21
+
22
+ import os
23
+ import json
24
+ import torch
25
+ import deepspeed
26
+ import safetensors
27
+ import transformers
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from abc import ABC, abstractmethod
31
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
32
+
33
+ from model.quant import VectorQuantizerM, AttnProjection
34
+ from model.multimodal_projector.builder import build_vision_projector
35
+ from model.multimodal_encoder.builder import build_vision_tower, build_vision_tower_aux
36
+ from constants import (
37
+ DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,
38
+ IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN
39
+ )
40
+
41
+
42
+ IS_NEW_TRANSFORMERS = tuple(int(v) for v in transformers.__version__.split('.')[:2]) >= (4, 34)  # numeric compare instead of string compare
43
+
44
+
45
+ class MiniGeminiMetaModel:
46
+ def __init__(self, config):
47
+ super(MiniGeminiMetaModel, self).__init__(config)
48
+ self.config = config
49
+ self.multi_embedder = TokenEmbedder(self.config.hidden_size)
50
+ if hasattr(config, "mm_vision_tower"):
51
+ self.vision_tower = build_vision_tower(config, delay_load=True)
52
+ self.mm_projector = build_vision_projector(config)
53
+ if hasattr(config, "mm_vision_tower_aux"):
54
+ self.vision_tower_aux = build_vision_tower_aux(config, delay_load=True)
55
+
56
+ def get_vision_tower(self):
57
+ vision_tower = getattr(self, 'vision_tower', None)
58
+ if type(vision_tower) is list:
59
+ vision_tower = vision_tower[0]
60
+ return vision_tower
61
+
62
+ def get_vision_tower_aux(self):
63
+ vision_tower_aux = getattr(self, 'vision_tower_aux', None)
64
+ if type(vision_tower_aux) is list:
65
+ vision_tower_aux = vision_tower_aux[0]
66
+ return vision_tower_aux
67
+
68
+ def initialize_embedder(self, unitok_pth, mm_projecter_pth=None):
69
+ self.multi_embedder = TokenEmbedder(self.config.hidden_size)
70
+
71
+ if unitok_pth is not None:
72
+ ckpt = torch.load(unitok_pth, map_location='cpu')
73
+ unitok_ckpt = ckpt['trainer']['unitok']
74
+ quantizer_weights = dict()
75
+ for k, v in unitok_ckpt.items():
76
+ if k.startswith('quantizer'):
77
+ new_k = k.replace('quantizer.', '')
78
+ quantizer_weights[new_k] = v
79
+ attn_proj_weights = dict()
80
+ for k, v in unitok_ckpt.items():
81
+ if k.startswith('post_quant_proj'):
82
+ new_k = k.replace('post_quant_proj.', '')
83
+ attn_proj_weights[new_k] = v
84
+
85
+ if is_deepspeed_zero3_enabled():
86
+ with deepspeed.zero.GatheredParameters(quantizer_weights, modifier_rank=0):
87
+ if torch.distributed.get_rank() == 0:
88
+ self.multi_embedder.quantizer.load_state_dict(quantizer_weights)
89
+ with deepspeed.zero.GatheredParameters(attn_proj_weights, modifier_rank=0):
90
+ if torch.distributed.get_rank() == 0:
91
+ self.multi_embedder.attn_projection.load_state_dict(attn_proj_weights)
92
+ else:
93
+ status = self.multi_embedder.quantizer.load_state_dict(quantizer_weights)
94
+ print('missing_keys:', status.missing_keys)
95
+ status = self.multi_embedder.attn_projection.load_state_dict(attn_proj_weights)
96
+ print('missing_keys:', status.missing_keys)
97
+
98
+ if mm_projecter_pth is not None:
99
+ mm_projector_weights = torch.load(mm_projecter_pth, map_location='cpu')
100
+
101
+ def get_w(weights, keyword):
102
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
103
+
104
+ named_parameters = get_w(mm_projector_weights, 'mm_projector')
105
+
106
+ if is_deepspeed_zero3_enabled():
107
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
108
+ if torch.distributed.get_rank() == 0:
109
+ self.multi_embedder.mm_projector.load_state_dict(named_parameters)
110
+ else:
111
+ status = self.multi_embedder.mm_projector.load_state_dict(named_parameters)
112
+ print('missing_keys:', status.missing_keys)
113
+
114
+ self.multi_embedder = self.multi_embedder.to(device='cuda')
115
+
116
+ def initialize_vision_modules(self, model_args, fsdp=None):
117
+ vision_tower = model_args.vision_tower
118
+ vision_tower_aux = model_args.vision_tower_aux
119
+ mm_vision_select_layer = model_args.mm_vision_select_layer
120
+ mm_vision_select_feature = model_args.mm_vision_select_feature
121
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
122
+
123
+ self.config.mm_vision_tower = vision_tower
124
+ self.config.mm_vision_tower_aux = vision_tower_aux
125
+
126
+ if self.get_vision_tower() is None:
127
+ vision_tower = build_vision_tower(model_args)
128
+
129
+ if fsdp is not None and len(fsdp) > 0:
130
+ self.vision_tower = [vision_tower]
131
+ else:
132
+ self.vision_tower = vision_tower
133
+ else:
134
+ if fsdp is not None and len(fsdp) > 0:
135
+ vision_tower = self.vision_tower[0]
136
+ else:
137
+ vision_tower = self.vision_tower
138
+ vision_tower.load_model()
139
+
140
+ if vision_tower_aux is not None:
141
+ if self.get_vision_tower_aux() is None:
142
+ vision_tower_aux = build_vision_tower_aux(model_args)
143
+
144
+ if fsdp is not None and len(fsdp) > 0:
145
+ self.vision_tower_aux = [vision_tower_aux]
146
+ else:
147
+ self.vision_tower_aux = vision_tower_aux
148
+ else:
149
+ if fsdp is not None and len(fsdp) > 0:
150
+ vision_tower_aux = self.vision_tower_aux[0]
151
+ else:
152
+ vision_tower_aux = self.vision_tower_aux
153
+ vision_tower_aux.load_model()
154
+ self.config.mm_hidden_size_aux = vision_tower_aux.hidden_size
155
+
156
+ self.config.use_mm_proj = True
157
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
158
+ self.config.mm_hidden_size = vision_tower.hidden_size
159
+ self.config.mm_vision_select_layer = mm_vision_select_layer
160
+ self.config.mm_vision_select_feature = mm_vision_select_feature
161
+
162
+ if getattr(self, 'mm_projector', None) is None:
163
+ self.mm_projector = build_vision_projector(self.config)
164
+ else:
165
+ # In case it is frozen by LoRA
166
+ for p in self.mm_projector.parameters():
167
+ p.requires_grad = True
168
+
169
+ if pretrain_mm_mlp_adapter is not None:
170
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
171
+
172
+ def get_w(weights, keyword):
173
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
174
+
175
+ if 'model' in mm_projector_weights.keys():
176
+ mm_projector_weights = mm_projector_weights['model']
177
+ if is_deepspeed_zero3_enabled():
178
+ if len(mm_projector_weights) > 0:
179
+ with deepspeed.zero.GatheredParameters(mm_projector_weights, modifier_rank=0):
180
+ if torch.distributed.get_rank() == 0:
181
+ self.mm_projector.load_state_dict(mm_projector_weights)
182
+ else:
183
+ status = self.mm_projector.load_state_dict(mm_projector_weights, strict=False)
184
+ print('missing_keys:', status.missing_keys)
185
+ else:
186
+ if is_deepspeed_zero3_enabled():
187
+ named_parameters = get_w(mm_projector_weights, 'mm_projector')
188
+ if len(named_parameters) > 0:
189
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
190
+ if torch.distributed.get_rank() == 0:
191
+ self.mm_projector.load_state_dict(named_parameters)
192
+ else:
193
+ status = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'),
194
+ strict=False)
195
+ print('missing_keys:', status.missing_keys)
196
+ self.mm_projector = self.mm_projector.to(device='cuda')
197
+
198
+ def initialize_uni_modules(self, model_args, for_eval=False):
199
+ pretrain_mm_mlp_adapter = getattr(model_args, "pretrain_mm_mlp_adapter", None)
200
+ self.config.image_size_aux = getattr(model_args, 'image_size_aux', 320)
201
+ self.config.optimize_vision_tower = getattr(model_args, 'optimize_vision_tower', False)
202
+ self.config.optimize_vision_tower_aux = getattr(model_args, 'optimize_vision_tower_aux', False)
203
+
204
+ self.vlm_uni_query_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size),
205
+ nn.Linear(self.config.mm_hidden_size, self.config.mm_hidden_size))
206
+ self.vlm_uni_aux_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
207
+ nn.Linear(self.config.mm_hidden_size_aux,
208
+ self.config.mm_hidden_size))
209
+ self.vlm_uni_val_projector = nn.Sequential(nn.LayerNorm(self.config.mm_hidden_size_aux),
210
+ nn.Linear(self.config.mm_hidden_size_aux,
211
+ self.config.mm_hidden_size))
212
+
213
+ if pretrain_mm_mlp_adapter is not None:
214
+ projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
215
+ else:
216
+ trainable_module = ['vlm_uni', 'vision_fpn', 'vision_stages']
217
+ if hasattr(model_args, 'model_name_or_path'):
218
+ model_save_path = model_args.model_name_or_path
219
+ else:
220
+ model_save_path = model_args.model_path
221
+ model_idx_path = getattr(model_args, 'model_path', model_save_path)
222
+ if IS_NEW_TRANSFORMERS:
223
+ try:
224
+ weight_file = json.load(open(os.path.join(model_idx_path, 'model.safetensors.index.json'), 'r'))[
225
+ 'weight_map']
226
+ except:
227
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))[
228
+ 'weight_map']
229
+ else:
230
+ weight_file = json.load(open(os.path.join(model_idx_path, 'pytorch_model.bin.index.json'), 'r'))[
231
+ 'weight_map']
232
+ model_path = set(
233
+ [weight_file[_key] for _key in weight_file if any([_module in _key for _module in trainable_module])])
234
+ projector_weights = {}
235
+ for _model in model_path:
236
+ if not IS_NEW_TRANSFORMERS:
237
+ projector_weights.update(torch.load(os.path.join(model_idx_path, _model), map_location='cpu'))
238
+ else:
239
+ with safetensors.safe_open(os.path.join(model_idx_path, _model), framework="pt", device='cpu') as f:
240
+ for _key in f.keys():
241
+ projector_weights.update({_key: f.get_tensor(_key)})
242
+ if len(projector_weights) == 0:
243
+ return
244
+
245
+ def get_w(weights, keyword, main_module, sub_module):
246
+ if getattr(main_module, sub_module, None) is None:
247
+ return
248
+
249
+ pretrain_weight = {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword + '.' in k}
250
+ if len(pretrain_weight) == 0:
251
+ return
252
+ if is_deepspeed_zero3_enabled():
253
+ named_parameters = [v for k, v in getattr(main_module, sub_module).named_parameters()]
254
+ if len(named_parameters) > 0:
255
+ # because zero3 puts placeholders in model params, this context
256
+ # manager gathers (unpartitions) the params of the current layer, then loads from
257
+ # the state dict and then re-partitions them again
258
+ with deepspeed.zero.GatheredParameters(named_parameters, modifier_rank=0):
259
+ if torch.distributed.get_rank() == 0:
260
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
261
+ with deepspeed.zero.GatheredParameters(self.mm_projector[0].weight, modifier_rank=None):
262
+ weight_type = self.mm_projector[0].weight.dtype
263
+ device_type = self.mm_projector[0].weight.device
264
+ else:
265
+ weight_type = self.mm_projector[0].weight.dtype
266
+ device_type = self.mm_projector[0].weight.device
267
+ getattr(main_module, sub_module).load_state_dict(pretrain_weight)
268
+ if weight_type == torch.uint8 or weight_type == torch.int8 or weight_type == torch.int16:
269
+ weight_type = torch.float16
270
+ getattr(main_module, sub_module).to(device=device_type, dtype=weight_type)
271
+ print(f"Loading {sub_module} weights...")
272
+
273
+ # load pretrained weights
274
+ get_w(projector_weights, 'vision_tower.vision_tower', self.vision_tower, 'vision_tower')
275
+
276
+ # load pretrained weights
277
+ if self.config.optimize_vision_tower_aux:
278
+ # the vision stem is not optimized; it is loaded only for checking
279
+ get_w(projector_weights, 'vision_tower_aux.vision_stem', self.vision_tower_aux, 'vision_stem')
280
+ get_w(projector_weights, 'vision_tower_aux.vision_stages', self.vision_tower_aux, 'vision_stages')
281
+ get_w(projector_weights, 'vlm_uni_query_projector', self, 'vlm_uni_query_projector')
282
+ get_w(projector_weights, 'vlm_uni_aux_projector', self, 'vlm_uni_aux_projector')
283
+ get_w(projector_weights, 'vlm_uni_val_projector', self, 'vlm_uni_val_projector')
284
+
285
+
286
+ class TokenEmbedder(nn.Module):
287
+ def __init__(self, hidden_size):
288
+ super().__init__()
289
+ # hard-coded for UniTok; needs to be made configurable
290
+ self.num_codebooks = 8
291
+ self.quantizer = VectorQuantizerM(32768, 64, 0.25, False, 0.01, 8)
292
+ self.attn_projection = AttnProjection(64, 1024, 16)
293
+ self.mm_projector = nn.Sequential(
294
+ nn.LayerNorm(1024, eps=1e-6),
295
+ nn.Linear(1024, hidden_size),
296
+ nn.GELU(),
297
+ nn.Linear(hidden_size, hidden_size),
298
+ )
299
+
300
+ def forward(self, indices): # input [bz,num-codebook,256]
301
+ assert indices.shape[1] == self.num_codebooks
302
+ features = self.quantizer.idx_to_f(indices) # [bz,256,C]
303
+ features = self.attn_projection(features) # [bz,256,1024]
304
+ latent_features = self.mm_projector(features) # [bz,256,hidden_size]
305
+ return latent_features # [bz,256,hidden_size]
306
+
307
+
308
+ class MiniGeminiMetaForCausalLM(ABC):
309
+ @abstractmethod
310
+ def get_model(self):
311
+ pass
312
+
313
+ def get_vision_tower(self):
314
+ return self.get_model().get_vision_tower()
315
+
316
+ def get_vision_tower_aux(self):
317
+ return self.get_model().get_vision_tower_aux()
318
+
319
+ def encode_images(self, images, images_aux=None, is_video=False):
320
+ image_grid = getattr(self.config, 'image_grid', 1)
321
+ image_global = getattr(self.config, 'image_global', False)
322
+ if image_grid > 1:
323
+ batch_size = images.shape[0]
324
+ if image_global:
325
+ global_images = images[:, -1:].flatten(0, 1).contiguous()
326
+ grid_images = images[:, :-1].flatten(0, 1).contiguous()
327
+ images = torch.cat([grid_images, global_images], dim=0)
328
+ else:
329
+ images = images.flatten(0, 1).contiguous()
330
+
331
+ image_features = self.get_model().get_vision_tower()(images)
332
+
333
+ if image_global:
334
+ image_feat_global = image_features[-len(global_images):]
335
+ image_features = image_features[:len(grid_images)]
336
+
337
+ if images_aux is not None:
338
+ image_aux_features_raw = self.get_model().get_vision_tower_aux()(images_aux).to(
339
+ dtype=image_features.dtype, device=image_features.device)
340
+
341
+ if image_global:
342
+ image_aux_features_global = F.interpolate(image_aux_features_raw.float(),
343
+ scale_factor=1 / image_grid,
344
+ mode='bilinear',
345
+ align_corners=False).to(dtype=image_aux_features_raw.dtype)
346
+ image_feat_global, image_aux_feat_global = self.unified_resampler(image_feat_global,
347
+ image_aux_features_global)
348
+
349
+ if image_grid > 1:
350
+ image_aux_features_raw = image_aux_features_raw.reshape(*image_aux_features_raw.shape[:2],
351
+ image_grid,
352
+ image_aux_features_raw.shape[-2] // image_grid,
353
+ image_grid,
354
+ image_aux_features_raw.shape[-1] // image_grid)
355
+ image_aux_features_raw = image_aux_features_raw.permute(0, 2, 4, 1, 3, 5).flatten(1, 2).flatten(0,
356
+ 1).contiguous()
357
+ image_features, image_aux_features = self.unified_resampler(image_features, image_aux_features_raw)
358
+
359
+ if image_grid > 1:
360
+ image_features = image_features.reshape(batch_size, image_grid ** 2, *image_features.shape[1:])
361
+ image_features = image_features.flatten(1, 2).contiguous()
362
+ image_aux_features = image_aux_features.reshape(batch_size, image_grid ** 2,
363
+ *image_aux_features.shape[1:])
364
+ image_aux_features = image_aux_features.flatten(1, 2).contiguous()
365
+
366
+ # add global features, [global, local]
367
+ if image_global:
368
+ image_features = torch.cat([image_feat_global, image_features], dim=1)
369
+ image_aux_features = torch.cat([image_aux_feat_global, image_aux_features], dim=1)
370
+
371
+ # token generation
372
+ image_features = image_features + image_aux_features
373
+
374
+ # process image features after token generation
375
+ image_features = self.get_model().mm_projector(image_features)
376
+
377
+ return image_features
378
+
379
+ def unified_resampler(self, images, images_aux):
380
+ # patchwise with square images
381
+ patch_num = int(images.shape[1] ** 0.5)
382
+ patch_size = images_aux.shape[-1] // patch_num
383
+ # within patch attention
384
+ images_aux = images_aux.permute(0, 2, 3, 1)
385
+ images_aux = images_aux.reshape(len(images_aux), patch_num, patch_size, patch_num, patch_size,
386
+ images_aux.shape[-1])
387
+ images_aux = images_aux.permute(0, 1, 3, 2, 4, 5)
388
+ images_aux = images_aux.reshape(len(images_aux), patch_num ** 2, patch_size ** 2,
389
+ images_aux.shape[-1]).contiguous()
390
+
391
+ # token attention
392
+ embed_query = self.get_model().vlm_uni_query_projector(images)
393
+ embed_aux = self.get_model().vlm_uni_aux_projector(images_aux)
394
+ embed_value = self.get_model().vlm_uni_val_projector(images_aux)
395
+ embed_att = embed_query[:, :, None] @ (embed_aux.transpose(-1, -2) / (embed_aux.shape[-1] ** 0.5))
396
+ embed_att = embed_att.nan_to_num()
397
+ embed_feat = (embed_att.softmax(-1) @ embed_value).mean(2)
398
+
399
+ return images, embed_feat
400
+
401
+ def prepare_inputs_labels_for_multimodal(
402
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images=None, images_aux=None,
403
+ data_types=None,
404
+ ):
405
+ vision_tower = self.get_vision_tower()
406
+ multi_embedder = self.model.multi_embedder
407
+ # import pdb;pdb.set_trace()
408
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
409
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
410
+ 1] == 1:
411
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
412
+ attention_mask = torch.cat((attention_mask, torch.ones(
413
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
414
+ dtype=attention_mask.dtype,
415
+ device=attention_mask.device
416
+ )), dim=1)
417
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
418
+
419
+ if position_ids is None:
420
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
421
+ bug_flag = False
422
+ if images is not None:
423
+ _labels = labels
424
+ _position_ids = position_ids
425
+ _attention_mask = attention_mask
426
+ new_input_embeds = []
427
+ new_labels = []
428
+ additional_image_labels = []
429
+ additional_image_indexs = []
430
+ if attention_mask is not None:
431
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
432
+ zip(input_ids, attention_mask)]
433
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in
434
+ zip(labels, attention_mask)]
435
+ # import pdb;pdb.set_trace()
436
+ for image, cur_input_ids, cur_labels, data_type in zip(images, input_ids, labels, data_types):
437
+ # import pdb;pdb.set_trace()
438
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
439
+ # import pdb;pdb.set_trace()
440
+ if num_images == 0:
441
+ # import pdb;pdb.set_trace()
442
+ empty_image_embed = multi_embedder(
443
+ torch.zeros(1, self.model.multi_embedder.num_codebooks, 1).long().to(cur_input_ids))[0, :0]
444
+ new_input_embeds.append(
445
+ torch.cat([self.get_model().embed_tokens(cur_input_ids), empty_image_embed], dim=0))
446
+ new_labels.append(cur_labels)
447
+ continue # pure text data
448
+ assert len(image.shape) == 3 # [bz,num-codebook,256] image token id
449
+ if len(image) > num_images:
450
+ image = image[:num_images] # remove images that were cut off by truncation
451
+ image_embedding = multi_embedder(image) # get image embeddings
452
+
453
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
454
+ cur_input_ids.shape[0]]
455
+ cur_input_ids_noim = []
456
+ cur_labels_noim = []
457
+ for i in range(len(image_token_indices) - 1):
458
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
459
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
460
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
461
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
462
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
463
+ cur_new_input_embeds = []
464
+ cur_new_labels = []
465
+ # import pdb;pdb.set_trace()
466
+ max_pos_id = 0
467
+ for i in range(num_images + 1):
468
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
469
+ cur_new_labels.append(cur_labels_noim[i])
470
+ # import pdb;pdb.set_trace()
471
+ max_pos_id += cur_input_embeds_no_im[i].shape[0]
472
+ if i < num_images:
473
+ cur_image_features = image_embedding[i]
474
+ cur_new_input_embeds.append(cur_image_features)
475
+
476
+ if data_type == 1: # to Image, loss on 4x image tokens
477
+ additional_image_labels.append(image)
478
+ additional_image_indexs.append((cur_new_labels[-1].shape[0],
479
+ cur_new_labels[-1].shape[0] + cur_image_features.shape[
480
+ 0]))
481
+ ### input: describe xxxx: boi 8*[256] (256 embedding) eoi eos
482
+ ### labels: -100 -100 -100 -100 -100 -100 -100 -100 -100 eoi eos
483
+ cur_new_labels.append(
484
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
485
+ dtype=cur_labels.dtype))
486
+ max_pos_id += cur_image_features.shape[0]
487
+
488
+ cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
489
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
490
+ cur_new_labels = torch.cat(cur_new_labels)
491
+
492
+ new_input_embeds.append(cur_new_input_embeds)
493
+ new_labels.append(cur_new_labels)
494
+
495
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
496
+
497
+ if tokenizer_model_max_length is not None:
498
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
499
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
500
+
501
+ # Combine them
502
+ max_len = max(x.shape[0] for x in new_input_embeds)
503
+ batch_size = len(new_input_embeds)
504
+ assert len(new_labels) == len(data_types)
505
+ new_input_embeds_padded = []
506
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
507
+ device=new_labels[0].device)
508
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype,
509
+ device=attention_mask.device)
510
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
511
+ # import pdb;pdb.set_trace()
512
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
513
+ cur_len = cur_new_embed.shape[0]
514
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
515
+ new_input_embeds_padded.append(torch.cat((
516
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
517
+ device=cur_new_embed.device),
518
+ cur_new_embed
519
+ ), dim=0))
520
+ if cur_len > 0:
521
+ new_labels_padded[i, -cur_len:] = cur_new_labels
522
+ attention_mask[i, -cur_len:] = True
523
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
524
+ device=position_ids.device)
525
+ else:
526
+ new_input_embeds_padded.append(torch.cat((
527
+ cur_new_embed,
528
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
529
+ device=cur_new_embed.device)
530
+ ), dim=0))
531
+ if cur_len > 0:
532
+ new_labels_padded[i, :cur_len] = cur_new_labels
533
+ attention_mask[i, :cur_len] = True
534
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
535
+ device=position_ids.device)
536
+
537
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
538
+
539
+ if _labels is None:
540
+ new_labels = None
541
+ else:
542
+ new_labels = new_labels_padded
543
+
544
+ if _attention_mask is None:
545
+ attention_mask = None
546
+ else:
547
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
548
+
549
+ if _position_ids is None:
550
+ position_ids = None
551
+ # import pdb;pdb.set_trace()
552
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, data_types, additional_image_labels, additional_image_indexs
553
+
554
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels, data_types, None, None # keep the 9-tuple shape unpacked in forward()
555
+
556
+ def prepare_inputs_for_multimodal(
557
+ self, input_ids, position_ids, attention_mask,
558
+ past_key_values, labels, images=None, images_aux=None, data_types=None,
559
+ ):
560
+ multi_embedder = self.model.multi_embedder
561
+ # import pdb;pdb.set_trace()
562
+ _labels = labels
563
+ _position_ids = position_ids
564
+ _attention_mask = attention_mask
565
+ if images is not None:
566
+ new_input_embeds = []
567
+ for image, cur_input_ids in zip(images, input_ids):
568
+ # import pdb;pdb.set_trace()
569
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
570
+ if num_images == 0:
571
+ new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
572
+ continue # pure text data
573
+ image_embedding = multi_embedder(image)
574
+ # import pdb;pdb.set_trace()
575
+
576
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
577
+ cur_input_ids.shape[0]]
578
+ cur_input_ids_noim = []
579
+ for i in range(len(image_token_indices) - 1):
580
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
581
+ split_sizes = [x.shape[0] for x in cur_input_ids_noim]
582
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
583
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
584
+ cur_new_input_embeds = []
585
+ # import pdb;pdb.set_trace()
586
+ max_pos_id = 0
587
+ for i in range(num_images + 1):
588
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
589
+ # import pdb;pdb.set_trace()
590
+ max_pos_id += cur_input_embeds_no_im[i].shape[0]
591
+ if i < num_images:
592
+ cur_image_features = image_embedding[i]
593
+ cur_new_input_embeds.append(cur_image_features)
594
+ max_pos_id += cur_image_features.shape[0]
595
+
596
+ cur_new_input_embeds = [x.to(device=cur_input_embeds.device) for x in cur_new_input_embeds]
597
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
598
+ new_input_embeds.append(cur_new_input_embeds)
599
+
600
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
601
+ if tokenizer_model_max_length is not None:
602
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
603
+ # import pdb;pdb.set_trace()
604
+ # Combine them
605
+ max_len = max(x.shape[0] for x in new_input_embeds)
606
+ batch_size = len(new_input_embeds)
607
+ new_input_embeds_padded = []
608
+
609
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
610
+ # import pdb;pdb.set_trace()
611
+ if _labels is None:
612
+ new_labels = None
613
+ else:
614
+ new_labels = new_labels_padded
615
+
616
+ if _attention_mask is None:
617
+ attention_mask = None
618
+ else:
619
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
620
+
621
+ if _position_ids is None:
622
+ position_ids = None
623
+ # import pdb;pdb.set_trace()
624
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
625
+
626
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
627
+ if model_args.mm_use_im_patch_token:
628
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
629
+ self.resize_token_embeddings(len(tokenizer))
630
+
631
+ if model_args.mm_use_im_start_end:
632
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
633
+ self.resize_token_embeddings(len(tokenizer))
634
+
635
+ if num_new_tokens > 0:
636
+ input_embeddings = self.get_input_embeddings().weight.data
637
+ output_embeddings = self.get_output_embeddings().weight.data
638
+
639
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
640
+ dim=0, keepdim=True)
641
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
642
+ dim=0, keepdim=True)
643
+
644
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
645
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
646
+
647
+ if model_args.tune_mm_mlp_adapter:
648
+ for p in self.get_input_embeddings().parameters():
649
+ p.requires_grad = True
650
+ for p in self.get_output_embeddings().parameters():
651
+ p.requires_grad = False
652
+
653
+ if model_args.pretrain_mm_mlp_adapter:
654
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
655
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
656
+ assert num_new_tokens == 2
657
+ if input_embeddings.shape == embed_tokens_weight.shape:
658
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
659
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
660
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
661
+ else:
662
+ raise ValueError(
663
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
664
+ elif model_args.mm_use_im_patch_token:
665
+ if model_args.tune_mm_mlp_adapter:
666
+ for p in self.get_input_embeddings().parameters():
667
+ p.requires_grad = False
668
+ for p in self.get_output_embeddings().parameters():
669
+ p.requires_grad = False
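A shape-level sketch of the TokenEmbedder contract defined above; it relies on VectorQuantizerM and AttnProjection from model/quant.py, hidden_size=4096 is an assumed LLM width, and with randomly initialized quantizer weights the embeddings are meaningless until initialize_embedder loads the UniTok checkpoint.
import torch
from model.liquid import TokenEmbedder

embedder = TokenEmbedder(hidden_size=4096)        # assumed hidden size
image_ids = torch.randint(0, 32768, (2, 8, 256))  # [batch, num_codebooks, image tokens]
embeds = embedder(image_ids)                      # [2, 256, 4096], matching the comments above
print(embeds.shape)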
model/multimodal_encoder/__pycache__/builder.cpython-311.pyc ADDED
Binary file (2.35 kB).
 
model/multimodal_encoder/__pycache__/builder.cpython-39.pyc ADDED
Binary file (1.27 kB).
 
model/multimodal_encoder/__pycache__/clip_encoder.cpython-311.pyc ADDED
Binary file (5.92 kB).
 
model/multimodal_encoder/__pycache__/clip_encoder.cpython-39.pyc ADDED
Binary file (3.37 kB).
 
model/multimodal_encoder/__pycache__/eva_encoder.cpython-311.pyc ADDED
Binary file (34.2 kB).
 
model/multimodal_encoder/__pycache__/eva_encoder.cpython-39.pyc ADDED
Binary file (17.3 kB).
 
model/multimodal_encoder/__pycache__/openclip_encoder.cpython-311.pyc ADDED
Binary file (12.3 kB).
 
model/multimodal_encoder/__pycache__/openclip_encoder.cpython-39.pyc ADDED
Binary file (6.52 kB).
 
model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ from .clip_encoder import CLIPVisionTower
3
+ from .eva_encoder import EVAVisionTower
4
+ from .openclip_encoder import OpenCLIPVisionTower
5
+
6
+
7
+ def build_vision_tower(vision_tower_cfg, **kwargs):
8
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9
+ image_processor = getattr(vision_tower_cfg, 'image_processor', "../processor/clip-patch14-224")
10
+
11
+ if not os.path.exists(vision_tower):
12
+ raise ValueError(f'Cannot find vision tower: {vision_tower}')
13
+
14
+ if "openai" in vision_tower.lower() or "ShareGPT4V" in vision_tower:
15
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16
+ elif "lavis" in vision_tower.lower() or "eva" in vision_tower.lower():
17
+ return EVAVisionTower(vision_tower, image_processor, args=vision_tower_cfg, **kwargs)
18
+ else:
19
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
20
+
21
+
22
+ def build_vision_tower_aux(vision_tower_cfg, **kwargs):
23
+ vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', getattr(vision_tower_cfg, 'vision_tower_aux', None))
24
+
25
+ if not os.path.exists(vision_tower_aux):
26
+ raise ValueError(f'Cannot find vision tower: {vision_tower_aux}')
27
+
28
+ if "openclip" in vision_tower_aux.lower():
29
+ return OpenCLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
30
+ elif "openai" in vision_tower_aux.lower():
31
+ return CLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
32
+ else:
33
+ raise ValueError(f'Unknown vision tower: {vision_tower_aux}')
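Both builders only read attributes off the config object and require the tower path to exist on disk, so a plain namespace is enough for a quick smoke test; the local path below is a placeholder that must point to a CLIP checkpoint directory whose name contains "openai".
from types import SimpleNamespace
from model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="checkpoints/openai-clip-vit-large-patch14-336",  # placeholder local path
    image_processor="../processor/clip-patch14-224",
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)  # defers weight loading; only the config is read from disk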
model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+ from ..processor.video_processor import VideoFramesProcessor
6
+
7
+ class CLIPVisionTower(nn.Module):
8
+ def __init__(self, vision_tower, args, delay_load=False):
9
+ super().__init__()
10
+
11
+ self.is_loaded = False
12
+
13
+ self.vision_tower_name = vision_tower
14
+ self.select_layer = args.mm_vision_select_layer
15
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16
+ self.is_optimize = getattr(args, 'optimize_vision_tower', False)
17
+
18
+ if not delay_load:
19
+ self.load_model()
20
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
21
+ self.load_model()
22
+ else:
23
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
24
+
25
+ def load_model(self):
26
+ self.image_processor = VideoFramesProcessor.from_pretrained(self.vision_tower_name)
27
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
28
+ self.vision_tower.requires_grad_(False)
29
+
30
+ self.is_loaded = True
31
+
32
+ def feature_select(self, image_forward_outs):
33
+ image_features = image_forward_outs.hidden_states[self.select_layer]
34
+ if self.select_feature == 'patch':
35
+ image_features = image_features[:, 1:]
36
+ elif self.select_feature == 'cls_patch':
37
+ image_features = image_features
38
+ else:
39
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
40
+ return image_features
41
+
42
+ def image_forward(self, images):
43
+ if type(images) is list:
44
+ image_features = []
45
+ for image in images:
46
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
47
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
48
+ image_features.append(image_feature)
49
+ else:
50
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
51
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
52
+
53
+ return image_features
54
+
55
+ def forward(self, images):
56
+ if not self.is_optimize:
57
+ with torch.no_grad():
58
+ image_features = self.image_forward(images)
59
+ else:
60
+ image_features = self.image_forward(images)
61
+
62
+ return image_features
63
+
64
+ @property
65
+ def dummy_feature(self):
66
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67
+
68
+ @property
69
+ def dtype(self):
70
+ return self.vision_tower.dtype
71
+
72
+ @property
73
+ def device(self):
74
+ return self.vision_tower.device
75
+
76
+ @property
77
+ def config(self):
78
+ if self.is_loaded:
79
+ return self.vision_tower.config
80
+ else:
81
+ return self.cfg_only
82
+
83
+ @property
84
+ def hidden_size(self):
85
+ return self.config.hidden_size
86
+
87
+ @property
88
+ def num_patches(self):
89
+ return (self.config.image_size // self.config.patch_size) ** 2
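A sketch of the feature-selection contract above: with mm_vision_select_feature='patch' the CLS token is dropped, so a 336-pixel, patch-14 CLIP tower yields 24 x 24 = 576 patch embeddings per image. The checkpoint id below is an example, not pinned by this repository, and the random pixels stand in for properly preprocessed images.
import torch
from types import SimpleNamespace
from model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)  # example checkpoint id
pixels = torch.randn(1, 3, 336, 336)
feats = tower(pixels)                  # runs under no_grad since optimize_vision_tower is unset
print(feats.shape, tower.num_patches)  # expected: torch.Size([1, 576, 1024]) 576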
model/multimodal_encoder/eva_encoder.py ADDED
@@ -0,0 +1,551 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+ from transformers import CLIPImageProcessor, CLIPVisionConfig
18
+ from ..processor.video_processor import VideoFramesProcessor
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+ class DropPath(nn.Module):
30
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31
+ """
32
+ def __init__(self, drop_prob=None):
33
+ super(DropPath, self).__init__()
34
+ self.drop_prob = drop_prob
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training)
38
+
39
+ def extra_repr(self) -> str:
40
+ return 'p={}'.format(self.drop_prob)
41
+
42
+
43
+ class Mlp(nn.Module):
44
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
45
+ super().__init__()
46
+ out_features = out_features or in_features
47
+ hidden_features = hidden_features or in_features
48
+ self.fc1 = nn.Linear(in_features, hidden_features)
49
+ self.act = act_layer()
50
+ self.fc2 = nn.Linear(hidden_features, out_features)
51
+ self.drop = nn.Dropout(drop)
52
+
53
+ def forward(self, x):
54
+ x = self.fc1(x)
55
+ x = self.act(x)
56
+ # x = self.drop(x)
57
+ # commented out to follow the original BERT implementation
58
+ x = self.fc2(x)
59
+ x = self.drop(x)
60
+ return x
61
+
62
+
63
+ class Attention(nn.Module):
64
+ def __init__(
65
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
66
+ proj_drop=0., window_size=None, attn_head_dim=None):
67
+ super().__init__()
68
+ self.num_heads = num_heads
69
+ head_dim = dim // num_heads
70
+ if attn_head_dim is not None:
71
+ head_dim = attn_head_dim
72
+ all_head_dim = head_dim * self.num_heads
73
+ self.scale = qk_scale or head_dim ** -0.5
74
+
75
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
76
+ if qkv_bias:
77
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
78
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ else:
80
+ self.q_bias = None
81
+ self.v_bias = None
82
+
83
+ if window_size:
84
+ self.window_size = window_size
85
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
86
+ self.relative_position_bias_table = nn.Parameter(
87
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88
+ # cls to token & token 2 cls & cls to cls
89
+
90
+ # get pair-wise relative position index for each token inside the window
91
+ coords_h = torch.arange(window_size[0])
92
+ coords_w = torch.arange(window_size[1])
93
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
94
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
95
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
96
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
97
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
98
+ relative_coords[:, :, 1] += window_size[1] - 1
99
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
100
+ relative_position_index = \
101
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
102
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
104
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
105
+ relative_position_index[0, 0] = self.num_relative_distance - 1
106
+
107
+ self.register_buffer("relative_position_index", relative_position_index)
108
+ else:
109
+ self.window_size = None
110
+ self.relative_position_bias_table = None
111
+ self.relative_position_index = None
112
+
113
+ self.attn_drop = nn.Dropout(attn_drop)
114
+ self.proj = nn.Linear(all_head_dim, dim)
115
+ self.proj_drop = nn.Dropout(proj_drop)
116
+
117
+ def forward(self, x, rel_pos_bias=None):
118
+ B, N, C = x.shape
119
+ qkv_bias = None
120
+ if self.q_bias is not None:
121
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
122
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
123
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
124
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
125
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
126
+
127
+ q = q * self.scale
128
+ attn = (q @ k.transpose(-2, -1))
129
+
130
+ if self.relative_position_bias_table is not None:
131
+ relative_position_bias = \
132
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
133
+ self.window_size[0] * self.window_size[1] + 1,
134
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
135
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
136
+ attn = attn + relative_position_bias.unsqueeze(0)
137
+
138
+ if rel_pos_bias is not None:
139
+ attn = attn + rel_pos_bias
140
+
141
+ attn = attn.softmax(dim=-1)
142
+ attn = self.attn_drop(attn)
143
+
144
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
145
+ x = self.proj(x)
146
+ x = self.proj_drop(x)
147
+ return x
148
+
149
+
150
+ class Block(nn.Module):
151
+
152
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
153
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
154
+ window_size=None, attn_head_dim=None):
155
+ super().__init__()
156
+ self.norm1 = norm_layer(dim)
157
+ self.attn = Attention(
158
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
159
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
160
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
161
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
162
+ self.norm2 = norm_layer(dim)
163
+ mlp_hidden_dim = int(dim * mlp_ratio)
164
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
165
+
166
+ if init_values is not None and init_values > 0:
167
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
168
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ else:
170
+ self.gamma_1, self.gamma_2 = None, None
171
+
172
+ def forward(self, x, rel_pos_bias=None):
173
+ if self.gamma_1 is None:
174
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
175
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
176
+ else:
177
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
178
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
179
+ return x
180
+
181
+
182
+ class PatchEmbed(nn.Module):
183
+ """ Image to Patch Embedding
184
+ """
185
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
186
+ super().__init__()
187
+ img_size = to_2tuple(img_size)
188
+ patch_size = to_2tuple(patch_size)
189
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
190
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
191
+ self.img_size = img_size
192
+ self.patch_size = patch_size
193
+ self.num_patches = num_patches
194
+
195
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
196
+
197
+ def forward(self, x, **kwargs):
198
+ B, C, H, W = x.shape
199
+ # FIXME look at relaxing size constraints
200
+ assert H == self.img_size[0] and W == self.img_size[1], \
201
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
202
+ x = self.proj(x).flatten(2).transpose(1, 2)
203
+ return x
204
+
205
+
206
+ class RelativePositionBias(nn.Module):
207
+
208
+ def __init__(self, window_size, num_heads):
209
+ super().__init__()
210
+ self.window_size = window_size
211
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
212
+ self.relative_position_bias_table = nn.Parameter(
213
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
214
+ # cls to token & token 2 cls & cls to cls
215
+
216
+ # get pair-wise relative position index for each token inside the window
217
+ coords_h = torch.arange(window_size[0])
218
+ coords_w = torch.arange(window_size[1])
219
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
220
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
221
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
222
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
223
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
224
+ relative_coords[:, :, 1] += window_size[1] - 1
225
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
226
+ relative_position_index = \
227
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
228
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
229
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
230
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
231
+ relative_position_index[0, 0] = self.num_relative_distance - 1
232
+
233
+ self.register_buffer("relative_position_index", relative_position_index)
234
+
235
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
236
+
237
+ def forward(self):
238
+ relative_position_bias = \
239
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240
+ self.window_size[0] * self.window_size[1] + 1,
241
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
242
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
243
+
244
+
245
+ class VisionTransformer(nn.Module):
246
+ """ Vision Transformer with support for patch or hybrid CNN input stage
247
+ """
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
249
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
250
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
251
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
252
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
253
+ super().__init__()
254
+ self.image_size = img_size
255
+ self.num_classes = num_classes
256
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
257
+
258
+ self.patch_embed = PatchEmbed(
259
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
260
+ num_patches = self.patch_embed.num_patches
261
+
262
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
263
+ if use_abs_pos_emb:
264
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
265
+ else:
266
+ self.pos_embed = None
267
+ self.pos_drop = nn.Dropout(p=drop_rate)
268
+
269
+ if use_shared_rel_pos_bias:
270
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
271
+ else:
272
+ self.rel_pos_bias = None
273
+ self.use_checkpoint = use_checkpoint
274
+
275
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
276
+ self.use_rel_pos_bias = use_rel_pos_bias
277
+ self.blocks = nn.ModuleList([
278
+ Block(
279
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
280
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
281
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
282
+ for i in range(depth)])
283
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
284
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
285
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
286
+
287
+ if self.pos_embed is not None:
288
+ trunc_normal_(self.pos_embed, std=.02)
289
+ trunc_normal_(self.cls_token, std=.02)
290
+ # trunc_normal_(self.mask_token, std=.02)
291
+ # if isinstance(self.head, nn.Linear):
292
+ # trunc_normal_(self.head.weight, std=.02)
293
+ self.apply(self._init_weights)
294
+ self.fix_init_weight()
295
+ # if isinstance(self.head, nn.Linear):
296
+ # self.head.weight.data.mul_(init_scale)
297
+ # self.head.bias.data.mul_(init_scale)
298
+
299
+ def fix_init_weight(self):
300
+ def rescale(param, layer_id):
301
+ param.div_(math.sqrt(2.0 * layer_id))
302
+
303
+ for layer_id, layer in enumerate(self.blocks):
304
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
305
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
306
+
307
+ def _init_weights(self, m):
308
+ if isinstance(m, nn.Linear):
309
+ trunc_normal_(m.weight, std=.02)
310
+ if isinstance(m, nn.Linear) and m.bias is not None:
311
+ nn.init.constant_(m.bias, 0)
312
+ elif isinstance(m, nn.LayerNorm):
313
+ nn.init.constant_(m.bias, 0)
314
+ nn.init.constant_(m.weight, 1.0)
315
+
316
+ def get_classifier(self):
317
+ return self.head
318
+
319
+ def reset_classifier(self, num_classes, global_pool=''):
320
+ self.num_classes = num_classes
321
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
322
+
323
+ def forward_features(self, x):
324
+ x = self.patch_embed(x)
325
+ batch_size, seq_len, _ = x.size()
326
+
327
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
328
+ x = torch.cat((cls_tokens, x), dim=1)
329
+ if self.pos_embed is not None:
330
+ x = x + self.pos_embed
331
+ x = self.pos_drop(x)
332
+
333
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
334
+ for blk in self.blocks:
335
+ if self.use_checkpoint:
336
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
337
+ else:
338
+ x = blk(x, rel_pos_bias)
339
+ return x
340
+ # x = self.norm(x)
341
+
342
+ # if self.fc_norm is not None:
343
+ # t = x[:, 1:, :]
344
+ # return self.fc_norm(t.mean(1))
345
+ # else:
346
+ # return x[:, 0]
347
+
348
+ def forward(self, x):
349
+ x = self.forward_features(x)
350
+ # x = self.head(x)
351
+ return x
352
+
353
+ def get_intermediate_layers(self, x):
354
+ x = self.patch_embed(x)
355
+ batch_size, seq_len, _ = x.size()
356
+
357
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
358
+ x = torch.cat((cls_tokens, x), dim=1)
359
+ if self.pos_embed is not None:
360
+ x = x + self.pos_embed
361
+ x = self.pos_drop(x)
362
+
363
+ features = []
364
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
365
+ for blk in self.blocks:
366
+ x = blk(x, rel_pos_bias)
367
+ features.append(x)
368
+
369
+ return features
370
+
371
+ @property
372
+ def dtype(self):
373
+ return self.cls_token.dtype
374
+
375
+ @property
376
+ def device(self):
377
+ return self.cls_token.device
378
+
379
+ def get_num_layer(self, var_name=""):
380
+ if var_name in ("cls_token", "mask_token", "pos_embed"):
381
+ return 0
382
+ elif var_name.startswith("patch_embed"):
383
+ return 0
384
+ elif var_name.startswith("rel_pos_bias"):
385
+ return len(self.blocks) - 1
386
+ elif var_name.startswith("blocks"):
387
+ layer_id = int(var_name.split('.')[1])
388
+ return layer_id + 1
389
+ else:
390
+ return len(self.blocks)
391
+
392
+
393
+ def interpolate_pos_embed(model, checkpoint_model):
394
+ if 'pos_embed' in checkpoint_model:
395
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
396
+ embedding_size = pos_embed_checkpoint.shape[-1]
397
+ num_patches = model.patch_embed.num_patches
398
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
399
+ # height (== width) for the checkpoint position embedding
400
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
401
+ # height (== width) for the new position embedding
402
+ new_size = int(num_patches ** 0.5)
403
+ # class_token and dist_token are kept unchanged
404
+ if orig_size != new_size:
405
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
406
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
407
+ # only the position tokens are interpolated
408
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
409
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
410
+ pos_tokens = torch.nn.functional.interpolate(
411
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
412
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
413
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
414
+ checkpoint_model['pos_embed'] = new_pos_embed
415
+
416
+
417
+ def convert_weights_to_fp16(model: nn.Module):
418
+ """Convert applicable model parameters to fp16"""
419
+
420
+ def _convert_weights_to_fp16(l):
421
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
422
+ l.weight.data = l.weight.data.half()
423
+ if l.bias is not None:
424
+ l.bias.data = l.bias.data.half()
425
+
426
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
427
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
428
+ # tensor = getattr(l, attr)
429
+ # if tensor is not None:
430
+ # tensor.data = tensor.data.half()
431
+
432
+ model.apply(_convert_weights_to_fp16)
433
+
434
+ class EVAVisionTower(nn.Module):
435
+ def __init__(self, vision_tower, image_processor, args, use_checkpoint=False, drop_path_rate=0.0, delay_load=False, dtype=torch.float32):
436
+ super().__init__()
437
+
438
+ self.is_loaded = False
439
+ self.use_checkpoint = use_checkpoint
440
+ self.vision_tower_name = vision_tower
441
+ self.image_processor_name = image_processor
442
+ self.drop_path_rate = drop_path_rate
443
+ self.patch_size = 14
444
+ self.out_channel = 1408
445
+ if not delay_load:
446
+ self.load_model()
447
+
448
+ self.vision_config = CLIPVisionConfig.from_pretrained(image_processor)
449
+
450
+ def load_model(self):
451
+ # self.image_processor = CLIPImageProcessor.from_pretrained(self.image_processor_name)
452
+ self.image_processor = VideoFramesProcessor.from_pretrained(self.image_processor_name)
453
+ self.vision_tower = VisionTransformer(
454
+ img_size=self.image_processor.size['shortest_edge'],
455
+ patch_size=self.patch_size,
456
+ use_mean_pooling=False,
457
+ embed_dim=self.out_channel,
458
+ depth=39,
459
+ num_heads=self.out_channel//88,
460
+ mlp_ratio=4.3637,
461
+ qkv_bias=True,
462
+ drop_path_rate=self.drop_path_rate,
463
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
464
+ use_checkpoint=self.use_checkpoint,
465
+ )
466
+
467
+ state_dict = torch.load(self.vision_tower_name, map_location="cpu")
468
+ interpolate_pos_embed(self.vision_tower, state_dict)
469
+ incompatible_keys = self.vision_tower.load_state_dict(state_dict, strict=False)
470
+ print(incompatible_keys)
471
+ self.vision_tower.requires_grad_(False)
472
+
473
+ self.is_loaded = True
474
+
475
+ @torch.no_grad()
476
+ def forward(self, images):
477
+ if type(images) is list:
478
+ image_features = []
479
+ for image in images:
480
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
481
+ image_feature = image_forward_out.to(image.dtype)
482
+ image_features.append(image_feature)
483
+ else:
484
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
485
+ image_features = image_forward_outs.to(images.dtype)
486
+
487
+ return image_features
488
+
489
+ def feature_select(self, image_features):
490
+ # image_features = image_features.hidden_states[self.select_layer]
491
+ if self.select_feature == 'patch':
492
+ image_features = image_features[:, 1:]
493
+ elif self.select_feature == 'cls_patch':
494
+ image_features = image_features
495
+ else:
496
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
497
+ return image_features
498
+
499
+ @property
500
+ def dummy_feature(self):
501
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
502
+
503
+ @property
504
+ def dtype(self):
505
+ return self.vision_tower.dtype
506
+
507
+ @property
508
+ def device(self):
509
+ return self.vision_tower.device
510
+
511
+ @property
512
+ def config(self):
513
+ return self.vision_config
514
+
515
+ @property
516
+ def hidden_size(self):
517
+ return self.out_channel
518
+
519
+ @property
520
+ def num_patches(self):
521
+ return (self.image_processor.size['shortest_edge'] // self.patch_size) ** 2
522
+
523
+
524
+
525
+ def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, model_path=None, precision="fp16"):
526
+ model = VisionTransformer(
527
+ img_size=img_size,
528
+ patch_size=14,
529
+ use_mean_pooling=False,
530
+ embed_dim=1408,
531
+ depth=39,
532
+ num_heads=1408//88,
533
+ mlp_ratio=4.3637,
534
+ qkv_bias=True,
535
+ drop_path_rate=drop_path_rate,
536
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
537
+ use_checkpoint=use_checkpoint,
538
+ )
539
+ # url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
540
+ # cached_file = download_cached_file(
541
+ # url, check_hash=False, progress=True
542
+ # )
543
+ state_dict = torch.load(model_path, map_location="cpu")
544
+ interpolate_pos_embed(model,state_dict)
545
+
546
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
547
+ print(incompatible_keys)
548
+
549
+ if precision == "fp16":
550
+ convert_weights_to_fp16(model)
551
+ return model
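Note: a quick smoke test of the EVA backbone defined above can be useful when porting checkpoints. The sketch below assumes the repo root is on PYTHONPATH and torch/timm/transformers are installed; depth=2 is only to keep the test cheap and is not the real ViT-g/39 configuration used by EVAVisionTower:

from functools import partial

import torch
import torch.nn as nn

from model.multimodal_encoder.eva_encoder import VisionTransformer

vit = VisionTransformer(
    img_size=224, patch_size=14, embed_dim=1408, depth=2, num_heads=16,
    mlp_ratio=4.3637, qkv_bias=True, use_mean_pooling=False,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
)
x = torch.randn(1, 3, 224, 224)
tokens = vit(x)                     # cls token + (224 // 14) ** 2 patch tokens
print(tokens.shape)                 # torch.Size([1, 257, 1408])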
model/multimodal_encoder/openclip_encoder.py ADDED
@@ -0,0 +1,188 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import json
6
+ import logging
7
+ import deepspeed
8
+ from pathlib import Path
9
+ from open_clip.factory import load_state_dict, get_model_config
10
+ from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
11
+ from typing import Dict, Optional
12
+ from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
13
+
14
+
15
+ class OpenCLIPVisionTower(nn.Module):
16
+ def __init__(self, vision_tower, args, delay_load=False):
17
+ super().__init__()
18
+
19
+ self.is_loaded = False
20
+ self.vision_tower_name = vision_tower
21
+ self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r'))
22
+ self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)
23
+
24
+ if not delay_load:
25
+ self.load_model()
26
+
27
+ def load_model(self):
28
+ ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
29
+ if 'convnext' in self.vision_tower_name:
30
+ if 'large' in self.vision_tower_name and 'd-320' in self.vision_tower_name:
31
+ self.model_type = 'convnext_large_d_320'
32
+ self.model_channel = [192, 384, 768, 1536] # stage 0-3
33
+ elif 'base' in self.vision_tower_name and 'w-320' in self.vision_tower_name:
34
+ self.model_type = 'convnext_base_w_320'
35
+ self.model_channel = [128, 256, 512, 1024]
36
+ elif 'xxlarge' in self.vision_tower_name:
37
+ self.model_type = 'convnext_xxlarge'
38
+ self.model_channel = [384, 768, 1536, 3072]
39
+
40
+ clip_model = CLIP(**get_model_config(self.model_type))
41
+ clip_model.visual.trunk.norm_pre = None
42
+ clip_model.visual.trunk.head = None
43
+ clip_model.visual.head = None
44
+ print(f'Loading pretrained weights ({self.model_type}).')
45
+ load_checkpoint(clip_model, ckpt_path, strict=False)
46
+
47
+ self.is_loaded = True
48
+ # decompose stem and stages blocks in vision tower
49
+ self.vision_stem = clip_model.visual.trunk.stem
50
+ self.vision_stages = clip_model.visual.trunk.stages
51
+ self.vision_stem.requires_grad_(False)
52
+ self.vision_stages.requires_grad_(False)
53
+
54
+ def forward(self, images):
55
+ if type(images) is list:
56
+ image_features = []
57
+ for image in images:
58
+ image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
59
+ image_features.append(image_feature)
60
+ else:
61
+ image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
62
+
63
+ return image_features
64
+
65
+ def backbone(self, images):
66
+ if not self.is_optimize:
67
+ with torch.no_grad():
68
+ results = self.basic_forward(images)
69
+ else:
70
+ results = self.basic_forward(images)
71
+
72
+ target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
73
+ result_cat = []
74
+ for _stage in results:
75
+ if _stage == 'stage_0':
76
+ result_cat.append(results[_stage].contiguous())
77
+ else:
78
+ result_cat.append(F.interpolate(results[_stage].float().contiguous() ,
79
+ size=target_size,
80
+ mode='bilinear',
81
+ align_corners=False).to(dtype=results[_stage].dtype))
82
+ result_cat = torch.cat(result_cat, dim=1)
83
+
84
+ return result_cat.contiguous()
85
+
86
+ def basic_forward(self, images):
87
+ results = {}
88
+ x = self.vision_stem(images)
89
+ for _idx in range(len(self.vision_stages)):
90
+ x = self.vision_stages[_idx](x)
91
+ results[f'stage_{_idx}'] = x
92
+ return results
93
+
94
+ @property
95
+ def dummy_feature(self):
96
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
97
+
98
+ @property
99
+ def dtype(self):
100
+ return self.vision_stem[0].weight.dtype
101
+
102
+ @property
103
+ def device(self):
104
+ return self.vision_stem[0].weight.device
105
+
106
+ @property
107
+ def config(self):
108
+ return self.vision_config
109
+
110
+ @property
111
+ def hidden_size(self):
112
+ return sum(self.model_channel)
113
+
114
+ # modified function from open_clip to support zero3 stage
115
+ def load_checkpoint(model, checkpoint_path, strict=True):
116
+ if Path(checkpoint_path).suffix in ('.npz', '.npy'):
117
+ from open_clip.big_vision import load_big_vision_weights
118
+ load_big_vision_weights(model, checkpoint_path)
119
+ return {}
120
+
121
+ state_dict = load_state_dict(checkpoint_path)
122
+ # detect old format and make compatible with new format
123
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
124
+ state_dict = convert_to_custom_text_state_dict(state_dict)
125
+ # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
126
+ # if 'logit_bias' not in state_dict and model.logit_bias is not None:
127
+ # state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
128
+ # Certain text transformers no longer expect position_ids after transformers==4.31
129
+ position_id_key = 'text.transformer.embeddings.position_ids'
130
+ if position_id_key in state_dict and not hasattr(model, position_id_key):
131
+ del state_dict[position_id_key]
132
+ resize_pos_embed(state_dict, model)
133
+ # resize_text_pos_embed(state_dict, model)
134
+ #incompatible_keys = model.load_state_dict(state_dict, strict=strict)
135
+ if is_deepspeed_zero3_enabled():
136
+
137
+ error_msgs = []
138
+
139
+ def load(module: nn.Module, state_dict, prefix=""):
140
+ metadata = None
141
+
142
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
143
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
144
+ # Parameters of module and children will start with prefix. We can exit early if there are none in this
145
+ # state_dict
146
+ if len([key for key in state_dict if key.startswith(prefix)]) > 0:
147
+ if is_deepspeed_zero3_enabled():
148
+ # In sharded models, each shard has only part of the full state_dict, so only gather
149
+ # parameters that are in the current state_dict.
150
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
151
+ params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
152
+ if len(params_to_gather) > 0:
153
+ # because zero3 puts placeholders in model params, this context
154
+ # manager gathers (unpartitions) the params of the current layer, then loads from
155
+ # the state dict and then re-partitions them again
156
+ with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
157
+ if torch.distributed.get_rank() == 0:
158
+ module._load_from_state_dict(*args)
159
+ else:
160
+ module._load_from_state_dict(*args)
161
+
162
+ for name, child in module._modules.items():
163
+ if child is not None:
164
+ load(child, state_dict, prefix + name + ".")
165
+
166
+ load(model, state_dict)
167
+ incompatible_keys = []
168
+ else:
169
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
170
+ logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
171
+ return incompatible_keys
172
+
173
+ class CLIP(nn.Module):
174
+ output_dict: torch.jit.Final[bool]
175
+
176
+ def __init__(
177
+ self,
178
+ embed_dim: int,
179
+ vision_cfg: CLIPVisionCfg,
180
+ text_cfg: CLIPTextCfg,
181
+ quick_gelu: bool = False,
182
+ cast_dtype: Optional[torch.dtype] = None,
183
+ output_dict: bool = False,
184
+ ):
185
+ super().__init__()
186
+ self.output_dict = output_dict
187
+
188
+ self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
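Note: the backbone() method above upsamples every ConvNeXt stage to the stage-0 resolution and concatenates them along channels, which is why hidden_size returns sum(self.model_channel). A standalone sketch of that shape arithmetic (tensor sizes are assumed for illustration, matching the convnext_large_d_320 channel list):

import torch
import torch.nn.functional as F

stage_channels = [192, 384, 768, 1536]              # self.model_channel for convnext_large_d_320
feats = {f'stage_{i}': torch.randn(1, c, 80 // 2 ** i, 80 // 2 ** i)
         for i, c in enumerate(stage_channels)}

target_size = feats['stage_0'].shape[-2:]
merged = torch.cat(
    [feats['stage_0']] +
    [F.interpolate(feats[f'stage_{i}'], size=target_size, mode='bilinear', align_corners=False)
     for i in range(1, 4)],
    dim=1)
print(merged.shape)                                 # torch.Size([1, 2880, 80, 80]); 2880 == sum(stage_channels)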
model/multimodal_projector/__pycache__/builder.cpython-311.pyc ADDED
Binary file (3.59 kB)
 
model/multimodal_projector/__pycache__/builder.cpython-39.pyc ADDED
Binary file (2.02 kB)
 
model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import re
4
+
5
+ class IdentityMap(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+
9
+ def forward(self, x, *args, **kwargs):
10
+ return x
11
+
12
+ @property
13
+ def config(self):
14
+ return {"mm_projector_type": 'identity'}
15
+
16
+
17
+ class SimpleResBlock(nn.Module):
18
+ def __init__(self, channels):
19
+ super().__init__()
20
+ self.pre_norm = nn.LayerNorm(channels)
21
+
22
+ self.proj = nn.Sequential(
23
+ nn.Linear(channels, channels),
24
+ nn.GELU(),
25
+ nn.Linear(channels, channels)
26
+ )
27
+ def forward(self, x):
28
+ x = self.pre_norm(x)
29
+ return x + self.proj(x)
30
+
31
+
32
+ def build_vision_projector(config, delay_load=False, **kwargs):
33
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
34
+
35
+ if projector_type == 'linear':
36
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
37
+
38
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
39
+ if mlp_gelu_match:
40
+ mlp_depth = int(mlp_gelu_match.group(1))
41
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
42
+ for _ in range(1, mlp_depth):
43
+ modules.append(nn.GELU())
44
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
45
+ return nn.Sequential(*modules)
46
+
47
+ if projector_type == 'identity':
48
+ return IdentityMap()
49
+
50
+ raise ValueError(f'Unknown projector type: {projector_type}')
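Note: a minimal usage sketch for build_vision_projector, assuming the repo root is importable; the config object only needs the three attributes read above (the sizes are made up for illustration):

from types import SimpleNamespace

from model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)
print(projector)    # Sequential: Linear(1024 -> 4096), GELU, Linear(4096 -> 4096)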
model/processor/__pycache__/video_processor.cpython-311.pyc ADDED
Binary file (4.86 kB)
 
model/processor/__pycache__/video_processor.cpython-39.pyc ADDED
Binary file (2.84 kB)
 
model/processor/video_processor.py ADDED
@@ -0,0 +1,74 @@
1
+ from transformers import CLIPImageProcessor
2
+ from transformers.image_processing_utils import BatchFeature, get_size_dict
3
+ from transformers.image_transforms import get_resize_output_image_size
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ import numpy as np
9
+
10
+
11
+ class VideoFramesProcessor(CLIPImageProcessor):
12
+
13
+ def __init__(self, **kwargs):
14
+ super().__init__(**kwargs)
15
+
16
+ def preprocess(self, images, **kwargs):
17
+ if not isinstance(images, np.ndarray):
18
+ return super().preprocess(images=images, **kwargs)
19
+
20
+ do_resize = kwargs.get('do_resize', self.do_resize)
21
+ size = kwargs.get('size', self.size)
22
+ size = get_size_dict(size, param_name="size", default_to_square=False)
23
+ do_center_crop = kwargs.get('do_center_crop', self.do_center_crop)
24
+ crop_size = kwargs.get('crop_size', self.crop_size)
25
+ crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
26
+ do_rescale = kwargs.get('do_rescale', self.do_rescale)
27
+ rescale_factor = kwargs.get('rescale_factor', self.rescale_factor)
28
+ do_normalize = kwargs.get('do_normalize', self.do_normalize)
29
+ image_mean = kwargs.get('image_mean', self.image_mean)
30
+ image_std = kwargs.get('image_std', self.image_std)
31
+ return_tensors = kwargs.get('return_tensors', None)
32
+
33
+ def resize(images, output_size):
34
+ images = images.permute((0, 3, 1, 2))
35
+ images = F.interpolate(images, size=output_size, mode='bicubic')
36
+ images = images.permute((0, 2, 3, 1))
37
+ return images
38
+
39
+ def center_crop(images, crop_size):
40
+ crop_width, crop_height = crop_size["width"], crop_size["height"]
41
+ img_height, img_width = images.shape[1:3]  # frames are (T, H, W, C) at this point
42
+ y = (img_height - crop_height) // 2
43
+ x = (img_width - crop_width) // 2
44
+ images = images[:, y:y+crop_height, x:x+crop_width]
45
+ return images
46
+
47
+ def rescale(images, rescale_factor):
48
+ images = images * rescale_factor
49
+ return images
50
+
51
+ def normalize(images, mean, std):
52
+ mean = torch.tensor(mean)
53
+ std = torch.tensor(std)
54
+ images = (images - mean) / std
55
+ return images
56
+
57
+ images = torch.from_numpy(images).float()
58
+
59
+ if do_resize:
60
+ output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False)
61
+ images = resize(images, output_size)
62
+
63
+ if do_center_crop:
64
+ images = center_crop(images, crop_size)
65
+
66
+ if do_rescale:
67
+ images = rescale(images, rescale_factor)
68
+
69
+ if do_normalize:
70
+ images = normalize(images, image_mean, image_std)
71
+
72
+ images = images.permute((0, 3, 1, 2))
73
+ data = {"pixel_values": images}
74
+ return BatchFeature(data=data, tensor_type=return_tensors)
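Note: a short sketch of feeding a stack of decoded video frames through VideoFramesProcessor (assuming the repo root is importable and a recent transformers release; the frame sizes below are arbitrary). A numpy array takes the tensorized path above; anything else falls back to the stock CLIPImageProcessor:

import numpy as np

from model.processor.video_processor import VideoFramesProcessor

proc = VideoFramesProcessor()                        # CLIP defaults: 224 shortest edge, 224x224 crop
frames = np.random.randint(0, 256, size=(8, 256, 320, 3)).astype(np.float32)   # (T, H, W, C)
batch = proc.preprocess(frames, return_tensors='pt')
print(batch['pixel_values'].shape)                   # torch.Size([8, 3, 224, 224])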
model/quant.py ADDED
@@ -0,0 +1,519 @@
1
+ from typing import List, Tuple
2
+ import torch
3
+ from torch import distributed as tdist, nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn.functional import scaled_dot_product_attention
6
+
7
+ # from utils import dist
8
+
9
+ # this file only provides the VectorQuantizer variants used in the VQVAE
10
+ __all__ = ['VectorQuantizer', ]
11
+
12
+ def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
13
+ E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
14
+ E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
15
+ logits = -E_dist.float().mul_(inv_entropy_tau)
16
+ # calc per_sample_entropy
17
+ prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
18
+ per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
19
+ # calc codebook_entropy
20
+ avg_prob = prob.mean(dim=0) # (vocab_size,)
21
+ log_avg_prob = torch.log(avg_prob + 1e-7)
22
+ codebook_entropy = (-avg_prob * log_avg_prob).sum()
23
+ # calc entropy_loss
24
+ entropy_loss = per_sample_entropy - codebook_entropy
25
+ return entropy_loss
26
+
27
+
28
+ class NormalizedEmbedding(nn.Embedding):
29
+ def __init__(self, num_embeddings: int, embedding_dim: int):
30
+ super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
31
+ # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
32
+
33
+ def forward(self, idx):
34
+ return F.embedding(
35
+ idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
36
+ self.norm_type, self.scale_grad_by_freq, self.sparse
37
+ )
38
+
39
+ def get_norm_weight(self):
40
+ return F.normalize(self.weight, dim=1)
41
+
42
+
43
+ class ResConv(nn.Conv2d):
44
+ def __init__(self, embed_dim, quant_resi):
45
+ ks = 3 if quant_resi < 0 else 1
46
+ super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
47
+ self.resi_ratio = abs(quant_resi)
48
+
49
+ def forward(self, h_BChw):
50
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
51
+
52
+
53
+ class VectorQuantizer(nn.Module):
54
+ def __init__(
55
+ self, vocab_size: int, vocab_width: int, vocab_norm: bool, beta: float = 0.25, quant_resi=-0.5,
56
+ using_entropy_loss=False, entropy_temp=0.01,
57
+ ):
58
+ super().__init__()
59
+ self.vocab_size: int = vocab_size
60
+ self.vocab_width: int = vocab_width
61
+ self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
62
+ self.vocab_usage_record_times: int = 0
63
+
64
+ self.vocab_norm: bool = vocab_norm
65
+ # self.quant_resi = ResConv(self.vocab_width, quant_resi=quant_resi)
66
+ self.quant_resi = nn.Identity()
67
+ self.embedding = nn.Embedding(self.vocab_size, self.vocab_width)
68
+ self.beta: float = beta
69
+
70
+ self.using_entropy_loss, self.inv_entropy_tau = using_entropy_loss, 1 / entropy_temp
71
+ if not self.vocab_norm:
72
+ assert not self.using_entropy_loss, 'entropy loss without vocab norm is not supported'
73
+
74
+ def init_vocab(self, eini: float):
75
+ if eini > 0:
76
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
77
+ elif eini < 0:
78
+ base = self.vocab_width ** -0.5
79
+ base /= 36
80
+ self.embedding.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
81
+
82
+ def extra_repr(self) -> str:
83
+ return f'beta={self.beta:g}'
84
+
85
+ # ===================== `forward` is only used in VAE training =====================
86
+ def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[
87
+ torch.Tensor, torch.Tensor, torch.Tensor, List[float]]:
88
+ f_BChw = f_BChw.float()
89
+ B, C, h, w = f_BChw.shape
90
+ if self.vocab_norm:
91
+ if self.using_entropy_loss:
92
+ # find the nearest neighbor
93
+ NxC = f_BChw.permute(0, 2, 3, 1).reshape(-1, C)
94
+ NxC_no_grad = NxC.detach()
95
+ NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
96
+ idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
97
+ # get logits
98
+ E_dist = NxC.square().sum(dim=1, keepdim=True) + self.embedding.weight.square().sum(dim=1,
99
+ keepdim=False)
100
+ E_dist.addmm_(NxC, self.embedding.weight.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
101
+ logits = -E_dist.float().mul_(self.inv_entropy_tau)
102
+ # calc per_sample_entropy
103
+ prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
104
+ per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
105
+ # calc codebook_entropy
106
+ avg_prob = prob.mean(dim=0) # (vocab_size,)
107
+ log_avg_prob = torch.log(avg_prob + 1e-7)
108
+ codebook_entropy = (-avg_prob * log_avg_prob).sum()
109
+ # calc entropy_loss
110
+ entropy_loss = per_sample_entropy - codebook_entropy
111
+ else:
112
+ NxC_no_grad = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
113
+ NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
114
+ idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
115
+ entropy_loss = 0
116
+ else: # not self.vocab_norm
117
+ NxC_no_grad = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
118
+ E_dist = NxC_no_grad.square().sum(dim=1, keepdim=True) + self.embedding.weight.data.square().sum(dim=1,
119
+ keepdim=False)
120
+ E_dist.addmm_(NxC_no_grad, self.embedding.weight.data.T, alpha=-2, beta=1) # E_dist: N x vocab_size
121
+ idx_N = torch.argmin(E_dist, dim=1)
122
+ entropy_loss = 0
123
+
124
+ prob_per_class_is_chosen = idx_N.bincount(minlength=self.vocab_size).float()
125
+ handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
126
+ self.training and dist.initialized()) else None
127
+
128
+ # look up
129
+ idx_Bhw = idx_N.view(B, h, w)
130
+ fhat_BChw = self.quant_resi(self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous())
131
+
132
+ # calc loss
133
+ vq_loss = F.mse_loss(fhat_BChw.detach(), f_BChw).mul_(self.beta) + F.mse_loss(fhat_BChw, f_BChw.detach())
134
+ fhat_BChw = (fhat_BChw.detach() - f_BChw.detach()).add_(f_BChw)
135
+
136
+ # update vocab_usage
137
+ if handler is not None:
138
+ handler.wait()
139
+ prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
140
+ vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
141
+
142
+ if self.vocab_usage_record_times == 0:
143
+ self.vocab_usage.copy_(prob_per_class_is_chosen)
144
+ elif self.vocab_usage_record_times < 100:
145
+ self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
146
+ else:
147
+ self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
148
+ self.vocab_usage_record_times += 1
149
+
150
+ return fhat_BChw, vq_loss, entropy_loss, (vocab_usage if ret_usages else None)
151
+
152
+ def f_to_idx(self, f_BChw: torch.Tensor) -> torch.LongTensor:
153
+ f_BChw = f_BChw.float()
154
+ B, C, h, w = f_BChw.shape
155
+ with torch.cuda.amp.autocast(enabled=False):
156
+ # find the nearest embedding
157
+ query_NxC = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C)
158
+ if self.vocab_norm:
159
+ query_NxC = F.normalize(query_NxC, dim=-1)
160
+ idx_N = torch.argmax(query_NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
161
+ else:
162
+ E_dist = torch.sum(query_NxC.square(), dim=1, keepdim=True) + torch.sum(
163
+ self.embedding.weight.data.square(), dim=1, keepdim=False)
164
+ E_dist.addmm_(query_NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
165
+ idx_N = torch.argmin(E_dist, dim=1)
166
+ return idx_N.view(B, h, w)
167
+
168
+
169
+ class VectorQuantizerHybrid(nn.Module):
170
+ def __init__(
171
+ self, vocab_size: int, vocab_width: int, vocab_norm: bool, beta: float = 0.25, quant_resi=-0.5,
172
+ using_entropy_loss=False, entropy_temp=0.01,
173
+ ):
174
+ super().__init__()
175
+ self.vocab_size: int = vocab_size
176
+ self.vocab_width: int = vocab_width
177
+ self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
178
+ self.vocab_usage_record_times: int = 0
179
+
180
+ self.vocab_norm: bool = vocab_norm
181
+ # self.quant_resi = ResConv(self.vocab_width, quant_resi=quant_resi)
182
+ self.embedding = nn.Embedding(self.vocab_size, self.vocab_width)
183
+ self.beta: float = beta
184
+
185
+ self.using_entropy_loss, self.inv_entropy_tau = using_entropy_loss, 1 / entropy_temp
186
+ if not self.vocab_norm:
187
+ assert not self.using_entropy_loss, 'entropy loss without vocab norm is not supported'
188
+
189
+ def init_vocab(self, eini: float):
190
+ if eini > 0:
191
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
192
+ elif eini < 0:
193
+ base = self.vocab_width ** -0.5
194
+ base /= 36
195
+ self.embedding.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
196
+
197
+ def extra_repr(self) -> str:
198
+ return f'beta={self.beta:g}'
199
+
200
+ def forward(self, class_tokens, patch_tokens, ret_usages=False):
201
+ class_tokens = class_tokens.float()
202
+ patch_tokens = patch_tokens.float()
203
+
204
+ B, L, C = class_tokens.shape
205
+ B, C, H, W = patch_tokens.shape
206
+ patch_tokens = patch_tokens.flatten(start_dim=2).permute(0, 2, 1)
207
+ NxC = torch.cat((class_tokens, patch_tokens), dim=1).reshape(-1, C)
208
+ if self.vocab_norm:
209
+ if self.using_entropy_loss:
210
+ # find the nearest neighbor
211
+ NxC_no_grad = NxC.detach()
212
+ NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
213
+ idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
214
+ # get logits
215
+ E_dist = NxC.square().sum(dim=1, keepdim=True) + self.embedding.weight.square().sum(dim=1,
216
+ keepdim=False)
217
+ E_dist.addmm_(NxC, self.embedding.weight.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
218
+ logits = -E_dist.float().mul_(self.inv_entropy_tau)
219
+ # calc per_sample_entropy
220
+ prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
221
+ per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
222
+ # calc codebook_entropy
223
+ avg_prob = prob.mean(dim=0) # (vocab_size,)
224
+ log_avg_prob = torch.log(avg_prob + 1e-7)
225
+ codebook_entropy = (-avg_prob * log_avg_prob).sum()
226
+ # calc entropy_loss
227
+ entropy_loss = per_sample_entropy - codebook_entropy
228
+ else:
229
+ NxC_no_grad = NxC.detach()
230
+ NxC_no_grad = F.normalize(NxC_no_grad, dim=-1)
231
+ idx_N = torch.argmax(NxC_no_grad @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
232
+ entropy_loss = 0
233
+ else: # not self.vocab_norm
234
+ NxC_no_grad = NxC.detach()
235
+ E_dist = NxC_no_grad.square().sum(dim=1, keepdim=True) + self.embedding.weight.data.square().sum(dim=1,
236
+ keepdim=False)
237
+ E_dist.addmm_(NxC_no_grad, self.embedding.weight.data.T, alpha=-2, beta=1) # E_dist: N x vocab_size
238
+ idx_N = torch.argmin(E_dist, dim=1)
239
+ entropy_loss = 0
240
+
241
+ prob_per_class_is_chosen = idx_N.bincount(minlength=self.vocab_size).float()
242
+ handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
243
+ self.training and dist.initialized()) else None
244
+
245
+ # look up
246
+ fhat = self.embedding(idx_N)
247
+
248
+ # calc loss
249
+ vq_loss = F.mse_loss(fhat.detach(), NxC).mul_(self.beta) + F.mse_loss(fhat, NxC.detach())
250
+ fhat = (fhat.detach() - NxC.detach()).add_(NxC)
251
+
252
+ # update vocab_usage
253
+ if handler is not None:
254
+ handler.wait()
255
+ prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
256
+ vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
257
+
258
+ if self.vocab_usage_record_times == 0:
259
+ self.vocab_usage.copy_(prob_per_class_is_chosen)
260
+ elif self.vocab_usage_record_times < 100:
261
+ self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
262
+ else:
263
+ self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
264
+ self.vocab_usage_record_times += 1
265
+
266
+ fhat = fhat.view(B, -1, C)
267
+ fhat_class = fhat[:, :L, :]
268
+ fhat_patch = fhat[:, L:, :].view(B, H, W, C).permute(0, 3, 1, 2)
269
+
270
+ return fhat_class, fhat_patch, vq_loss, entropy_loss, (vocab_usage if ret_usages else None)
271
+
272
+ def f_to_idx(self, class_tokens, patch_tokens) -> torch.LongTensor:
273
+ B, L, C = class_tokens.shape
274
+ B, C, H, W = patch_tokens.shape
275
+ class_tokens = class_tokens.float()
276
+ patch_tokens = patch_tokens.float()
277
+ patch_tokens = patch_tokens.flatten(start_dim=2).permute(0, 2, 1)
278
+ NxC = torch.cat((class_tokens, patch_tokens), dim=1).reshape(-1, C)
279
+ with torch.cuda.amp.autocast(enabled=False):
280
+ # find the nearest embedding
281
+ if self.vocab_norm:
282
+ NxC = F.normalize(NxC, dim=-1)
283
+ idx_N = torch.argmax(NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
284
+ else:
285
+ E_dist = torch.sum(NxC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(),
286
+ dim=1, keepdim=False)
287
+ E_dist.addmm_(NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
288
+ idx_N = torch.argmin(E_dist, dim=1)
289
+ return idx_N
290
+
291
+
292
+ class VectorQuantizerX(nn.Module):
293
+ def __init__(
294
+ self,
295
+ vocab_size: int,
296
+ vocab_width: int,
297
+ beta: float = 0.25,
298
+ use_entropy_loss=False,
299
+ entropy_temp=0.01,
300
+ ):
301
+ super().__init__()
302
+ self.beta = beta
303
+ self.vocab_size = vocab_size
304
+ self.vocab_width = vocab_width
305
+ self.vocab_usage_record_times: int = 0
306
+ self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
307
+
308
+ self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)
309
+
310
+ self.use_entropy_loss = use_entropy_loss
311
+ self.inv_entropy_tau = 1 / entropy_temp
312
+
313
+ def init_vocab(self, eini: float):
314
+ if eini > 0:
315
+ nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
316
+ elif eini < 0:
317
+ base = self.vocab_width ** -0.5
318
+ base /= 36
319
+ self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
320
+
321
+ def extra_repr(self) -> str:
322
+ return f'beta={self.beta:g}'
323
+
324
+ def forward(self, features):
325
+ B, L, C = features.shape
326
+ features = features.reshape(-1, C)
327
+ features = F.normalize(features, dim=-1).float()
328
+ codebook_embed = self.codebook.get_norm_weight()
329
+ indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
330
+ entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
331
+ features_hat = self.codebook(indices)
332
+
333
+ # calc loss
334
+ vq_loss = F.mse_loss(features_hat.detach(), features).mul_(self.beta) + F.mse_loss(features_hat,
335
+ features.detach())
336
+ features_hat = (features_hat.detach() - features.detach()).add_(features)
337
+
338
+ # update vocab_usage
339
+ prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
340
+ handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
341
+ self.training and dist.initialized()) else None
342
+ if handler is not None:
343
+ handler.wait()
344
+ prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
345
+ vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
346
+ if self.vocab_usage_record_times == 0:
347
+ self.vocab_usage.copy_(prob_per_class_is_chosen)
348
+ elif self.vocab_usage_record_times < 100:
349
+ self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
350
+ else:
351
+ self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
352
+ self.vocab_usage_record_times += 1
353
+
354
+ return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage
355
+
356
+ def f_to_idx(self, features):
357
+ B, L, C = features.shape
358
+ features = features.reshape(-1, C)
359
+ features = F.normalize(features, dim=-1).float()
360
+ codebook_embed = self.codebook.get_norm_weight().float()
361
+ indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
362
+ return indices.view(B, L)
363
+
364
+
365
+ class VectorQuantizerM(nn.Module):
366
+ def __init__(
367
+ self,
368
+ vocab_size,
369
+ vocab_width,
370
+ beta=0.25,
371
+ use_entropy_loss=False,
372
+ entropy_temp=0.01,
373
+ num_codebooks=16
374
+ ):
375
+ super().__init__()
376
+ self.num_codebooks = num_codebooks
377
+ self.codebooks = nn.ModuleList()
378
+ for _ in range(num_codebooks):
379
+ codebook = VectorQuantizerX(
380
+ vocab_size=vocab_size // num_codebooks,
381
+ vocab_width=vocab_width // num_codebooks,
382
+ beta=beta,
383
+ use_entropy_loss=use_entropy_loss,
384
+ entropy_temp=entropy_temp,
385
+ )
386
+ self.codebooks.append(codebook)
387
+
388
+ def init_vocab(self, eini: float):
389
+ for codebook in self.codebooks:
390
+ codebook.init_vocab(eini)
391
+
392
+ def f_to_idx(self, features):
393
+ indices = []
394
+ chunk_size = features.shape[-1] // self.num_codebooks
395
+ splited_features = features.split(chunk_size, dim=-1)
396
+ for i, codebook in enumerate(self.codebooks):
397
+ indices.append(codebook.f_to_idx(splited_features[i]))
398
+ indices = torch.stack(indices, dim=1)
399
+ return indices
400
+
401
+ def idx_to_f(self, indices):
402
+ assert indices.shape[1] == self.num_codebooks
403
+ latent_features = []
404
+ for i, codebook in enumerate(self.codebooks):
405
+ sub_indices = indices[:, i].flatten(start_dim=1)
406
+ latent_feature = codebook.codebook(sub_indices)
407
+ latent_features.append(latent_feature)
408
+ latent_features = torch.cat(latent_features, dim=-1)
409
+ return latent_features
410
+
411
+ def forward(self, features):
412
+ latent_features = []
413
+ global_vq_loss = 0.
414
+ global_entropy_loss = 0.
415
+ global_vocab_usage = 0.
416
+ chunk_size = features.shape[-1] // self.num_codebooks
417
+ splited_features = features.split(chunk_size, dim=-1)
418
+ for i, codebook in enumerate(self.codebooks):
419
+ latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(splited_features[i])
420
+ latent_features.append(latent_feature)
421
+ global_vq_loss += vq_loss
422
+ global_entropy_loss += entropy_loss
423
+ global_vocab_usage += vocab_usage
424
+ latent_features = torch.cat(latent_features, dim=-1)
425
+ global_entropy_loss /= self.num_codebooks
426
+ global_vq_loss /= self.num_codebooks
427
+ global_vocab_usage /= self.num_codebooks
428
+ return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
429
+
430
+
431
+ class CausalAttention(nn.Module):
432
+ def __init__(self, in_dim, out_dim, num_heads):
433
+ super().__init__()
434
+ if in_dim > out_dim:
435
+ # assert in_dim // num_heads == out_dim
436
+ self.head_dim = in_dim // num_heads
437
+ self.qkv = nn.Linear(in_dim, in_dim * 3, bias=False)
438
+ self.q_bias = nn.Parameter(torch.zeros(in_dim))
439
+ self.v_bias = nn.Parameter(torch.zeros(in_dim))
440
+ self.register_buffer('zero_k_bias', torch.zeros(in_dim))
441
+ else:
442
+ # assert out_dim // num_heads == in_dim
443
+ self.head_dim = out_dim // num_heads
444
+ self.qkv = nn.Linear(in_dim, out_dim * 3, bias=False)
445
+ self.q_bias = nn.Parameter(torch.zeros(out_dim))
446
+ self.v_bias = nn.Parameter(torch.zeros(out_dim))
447
+ self.register_buffer('zero_k_bias', torch.zeros(out_dim))
448
+
449
+ self.in_dim = in_dim
450
+ self.out_dim = out_dim
451
+ self.num_heads = num_heads
452
+ self.scale = self.head_dim ** -0.5
453
+ self.proj = nn.Linear(out_dim, out_dim)
454
+
455
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
456
+ B, N, C = x.shape
457
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)))
458
+ q, k, v = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
459
+
460
+ x = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0., is_causal=True)
461
+
462
+ if self.in_dim > self.out_dim:
463
+ x = torch.mean(x, dim=1)
464
+ if self.in_dim // self.num_heads != self.out_dim:
465
+ x = nn.functional.adaptive_avg_pool1d(x, self.out_dim)
466
+ else:
467
+ x = x.transpose(1, 2).reshape(B, N, -1)
468
+ x = self.proj(x)
469
+ return x
470
+
471
+
472
+ class AttnProjection(nn.Module):
473
+ def __init__(self, in_dim, out_dim, num_heads, norm_layer=nn.LayerNorm, mlp_ratio=2):
474
+ super().__init__()
475
+ assert out_dim % in_dim == 0 or in_dim % out_dim == 0
476
+ self.in_dim = in_dim
477
+ self.out_dim = out_dim
478
+ self.norm1 = norm_layer(in_dim)
479
+ self.attn = CausalAttention(in_dim, out_dim, num_heads)
480
+ self.proj = nn.Linear(in_dim, out_dim)
481
+ self.norm3 = norm_layer(in_dim)
482
+
483
+ self.norm2 = norm_layer(out_dim)
484
+ hidden_dim = int(out_dim * mlp_ratio)
485
+ self.mlp = GeGluMlp(
486
+ in_features=out_dim,
487
+ hidden_features=hidden_dim
488
+ )
489
+
490
+ def forward(self, x):
491
+ x = self.proj(self.norm3(x)) + self.attn(self.norm1(x))
492
+ x = x + self.mlp(self.norm2(x))
493
+ return x
494
+
495
+
496
+ from functools import partial
497
+ from timm.models.layers import create_conv2d, get_norm_act_layer, get_norm_layer, make_divisible
498
+
499
+ class GeGluMlp(nn.Module):
500
+ def __init__(
501
+ self,
502
+ in_features,
503
+ hidden_features,
504
+ act_layer = None,
505
+ drop = 0.0,
506
+ ):
507
+ super().__init__()
508
+ norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
509
+ self.norm = norm_layer(in_features)
510
+ self.act = nn.GELU(approximate='tanh')
511
+ self.w0 = nn.Linear(in_features, hidden_features)
512
+ self.w1 = nn.Linear(in_features, hidden_features)
513
+ self.w2 = nn.Linear(hidden_features, in_features)
514
+
515
+ def forward(self, x):
516
+ x = self.norm(x)
517
+ x = self.act(self.w0(x)) * self.w1(x)
518
+ x = self.w2(x)
519
+ return x
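Note: the multi-codebook quantizer above (VectorQuantizerM) splits the channel dimension across num_codebooks sub-codebooks and quantizes each chunk independently. A toy round-trip sketch, assuming the repo root is importable and torch/timm are installed; the sizes are made up for illustration:

import torch

from model.quant import VectorQuantizerM

quantizer = VectorQuantizerM(vocab_size=32768, vocab_width=1024, num_codebooks=16)
feats = torch.randn(2, 256, 1024)          # (B, L, C) continuous features
idx = quantizer.f_to_idx(feats)            # (B, num_codebooks, L) sub-codebook indices
rec = quantizer.idx_to_f(idx)              # (B, L, C) concatenation of the 16 looked-up chunks
print(idx.shape, rec.shape)                # torch.Size([2, 16, 256]) torch.Size([2, 256, 1024])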
t2i.py ADDED
@@ -0,0 +1,224 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import argparse
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from torchvision import transforms
8
+ from torch.nn import functional as F
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+
11
+ from model import *
12
+ from unitok.config import Args
13
+ from unitok.model import UniTok
14
+
15
+
16
+ PILtransform = transforms.ToPILImage()
17
+
18
+
19
+ def top_k_top_p_filtering(
20
+ logits,
21
+ top_k: int = 0,
22
+ top_p: float = 1.0,
23
+ filter_value: float = -float("Inf"),
24
+ min_tokens_to_keep: int = 1,
25
+ ):
26
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
27
+ Args:
28
+ logits: logits distribution shape (batch size, vocabulary size)
29
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
30
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
31
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
32
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
33
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
34
+ """
35
+
36
+ if top_k > 0:
37
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
38
+ # Remove all tokens with a probability less than the last token of the top-k
39
+
40
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
41
+ logits[indices_to_remove] = filter_value
42
+
43
+ if top_p < 1.0:
44
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
45
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
46
+
47
+ # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
48
+ sorted_indices_to_remove = cumulative_probs > top_p
49
+ if min_tokens_to_keep > 1:
50
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
51
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
52
+ # Shift the indices to the right to keep also the first token above the threshold
53
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
54
+ sorted_indices_to_remove[..., 0] = 0
55
+
56
+ # scatter sorted tensors to original indexing
57
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
58
+ logits[indices_to_remove] = filter_value
59
+ # import pdb;pdb.set_trace()
60
+ return logits
61
+
62
+
63
+ def sample(logits, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0, sample_logits=True):
64
+ logits = logits[:, -1, :] / max(temperature, 1e-5)
65
+ if top_k > 0 or top_p < 1.0:
66
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
67
+ probs = F.softmax(logits, dim=-1)
68
+ if sample_logits:
69
+ idx = torch.multinomial(probs, num_samples=1)
70
+ else:
71
+ _, idx = torch.topk(probs, k=1, dim=-1)
72
+ return idx, probs
73
+
74
+
75
+ def split_list(input_list, chunk_size):
76
+ return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
77
+
78
+
79
+ def get_args_parser():
80
+ parser = argparse.ArgumentParser('UniTok text-to-image inference', add_help=False)
81
+ parser.add_argument('--unitok_path', type=str, required=True)
82
+ parser.add_argument('--mllm_path', type=str, required=True)
83
+ parser.add_argument('--prompt_file', type=str, required=True)
84
+ parser.add_argument('--result_dir', type=str, required=True)
85
+ parser.add_argument('--idx', type=int, default=0)
86
+ parser.add_argument('--tau', type=float, default=0.9)
87
+ parser.add_argument('--topk', type=int, default=2048)
88
+ parser.add_argument('--topp', type=float, default=0.96)
89
+ parser.add_argument('--cfg_scale', type=float, default=5.0)
90
+ return parser
91
+
92
+
93
+ def main(args):
94
+ text_set_id = args.idx
95
+ tau = args.tau
96
+ topk = args.topk
97
+ topp = args.topp
98
+ cfg_scale = args.cfg_scale
99
+
100
+ print('loading vq model ...')
101
+ ckpt = torch.load(args.unitok_path, map_location='cpu')
102
+ vae_cfg = Args()
103
+ vae_cfg.load_state_dict(ckpt['args'])
104
+ vq_model = UniTok(vae_cfg)
105
+ vq_model.load_state_dict(ckpt['trainer']['unitok'])
106
+ vq_model.to('cuda')
107
+ vq_model.eval()
108
+
109
+ image_save_pth = '{}/GenAI-cfg_{}-topk_{}-topp_{}-tau_{}'.format(args.result_dir, str(cfg_scale), str(topk), str(topp), str(tau))
110
+
111
+ tokenizer = AutoTokenizer.from_pretrained(args.mllm_path, padding_side='left')
112
+ vqllm = AutoModelForCausalLM.from_pretrained(
113
+ args.mllm_path,
114
+ attn_implementation='flash_attention_2',
115
+ torch_dtype=torch.bfloat16
116
+ ).to('cuda')
117
+
118
+ num_processes = 8
119
+ chunk_size = 8 # batchsize
120
+ num_codebooks = vae_cfg.num_codebooks
121
+
122
+ with open(args.prompt_file, 'r') as f:
123
+ lines = f.readlines()
124
+ all_prompts = []
125
+ for index, line in enumerate(lines):
126
+ all_prompts.append({'Index': str(index + 1).zfill(5), 'Prompt': line.strip()})
127
+
128
+ chunked_filenames = np.array_split(all_prompts, num_processes)
129
+ subset = chunked_filenames[text_set_id].tolist()
130
+ chunk_inputs = split_list(subset, chunk_size)
131
+ for chunk in tqdm(chunk_inputs):
132
+ text_inputs = [v['Prompt'] for v in chunk]
133
+ uncondition_text_inputs = ['<unconditional>'] * len(text_inputs)
134
+ for i in range(len(text_inputs)):
135
+ text_inputs[i] = text_inputs[i] + ' Generate an image based on this description.'
136
+ ori_batchsize = len(text_inputs)
137
+
138
+ save_list = []
139
+ if cfg_scale > 1:
140
+ model_inputs = tokenizer(text_inputs + uncondition_text_inputs, return_tensors="pt", padding=True).to('cuda')
141
+ total_batchsize = len(text_inputs + uncondition_text_inputs)
142
+ model_inputs['input_ids'] = torch.cat([
143
+ model_inputs['input_ids'],
144
+ torch.empty(total_batchsize, 1).fill_(3).to(model_inputs['input_ids'])
145
+ ], dim=1)
146
+ model_inputs['attention_mask'] = torch.cat([
147
+ model_inputs['attention_mask'],
148
+ torch.empty(total_batchsize, 1).fill_(1).to(model_inputs['attention_mask'])
149
+ ], dim=1)
150
+ else:
151
+ model_inputs = tokenizer(text_inputs, return_tensors="pt", padding=True).to('cuda')
152
+ total_batchsize = len(text_inputs)
153
+ model_inputs['input_ids'] = torch.cat([
154
+ model_inputs['input_ids'],
155
+ torch.empty(total_batchsize, 1).fill_(3).to(model_inputs['input_ids'])
156
+ ], dim=1)
157
+ model_inputs['attention_mask'] = torch.cat([
158
+ model_inputs['attention_mask'],
159
+ torch.empty(total_batchsize, 1).fill_(1).to(model_inputs['attention_mask'])
160
+ ], dim=1)
161
+ with torch.no_grad():
162
+ sampling_kwargs = {'temperature': tau, 'top_k': topk, 'top_p': topp, 'sample_logits': True}
163
+ pred_tokens = []
164
+ input_multi_ids = None
165
+ for _ in range(256):
166
+ outputs = vqllm.T2I_forward_nocache(
167
+ **model_inputs,
168
+ input_multi_ids=input_multi_ids,
169
+ use_cache=None,
170
+ return_dict=True,
171
+ output_attentions=False,
172
+ output_hidden_states=False,
173
+ )
174
+ next_embed = outputs['last_hidden_state'][:, -1:, :]
175
+
176
+ indices_arhead = []
177
+ for i_head in range(num_codebooks):
178
+ ar_next_embed = vqllm.ar_head(
179
+ inputs_embeds=next_embed,
180
+ use_cache=False,
181
+ output_attentions=False,
182
+ output_hidden_states=False,
183
+ return_dict=False,
184
+ )
185
+ next_token_logits = vqllm.ar_head.linear_head(ar_next_embed)
186
+ if cfg_scale > 1:
187
+ cond_logits, uncond_logits = torch.split(next_token_logits, len(next_token_logits) // 2, dim=0)
188
+ cfg_logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
189
+ half_next_token, _ = sample(cfg_logits, **sampling_kwargs)
190
+ next_token = torch.cat([half_next_token, half_next_token]) # [bz*2, 1]
191
+ else:
192
+ next_token, next_prob = sample(next_token_logits, **sampling_kwargs)
193
+
194
+ indices_arhead.append(next_token)
195
+ if i_head < num_codebooks - 1:
196
+ predicted_embed = vqllm.ar_head.codebooks[i_head](next_token)
197
+ next_embed = torch.cat([next_embed, predicted_embed], dim=1)
198
+
199
+ # update generated ids, model inputs, and length for next step
200
+ pred_tokens.append(torch.cat(indices_arhead, dim=1)) # [bz*2, num_codebooks]
201
+ input_multi_ids = torch.stack(pred_tokens, dim=-1)
202
+
203
+ del sampling_kwargs, model_inputs, outputs
204
+
205
+ image_vq_id = torch.stack(pred_tokens, dim=-1)[:ori_batchsize]
206
+ save_list.append(image_vq_id)
207
+
208
+ torch.cuda.empty_cache()
209
+
210
+ print('decoding images ...')
211
+ if not os.path.exists(image_save_pth):
212
+ os.makedirs(image_save_pth)
213
+ for datainfo, vq_code in zip(chunk, save_list[0]):
214
+ idx = datainfo['Index']
215
+ new_gen_ids = vq_code.unsqueeze(0).to('cuda')
216
+ rec_image = vq_model.idx_to_img(new_gen_ids)
217
+ rec_img = PILtransform(rec_image.squeeze(0).add(1).mul_(0.5).clamp_(0, 1))
218
+ rec_img.save('{}/{}.jpg'.format(image_save_pth, str(idx)))
219
+
220
+
221
+ if __name__ == '__main__':
222
+ parser = argparse.ArgumentParser('genai inference script', parents=[get_args_parser()])
223
+ args = parser.parse_args()
224
+ main(args)
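For reference, a toy run of the classifier-free-guidance mixing and the sampling helpers above; the vocabulary size, guidance scale, and sampling settings are made up for illustration:

import torch

logits = torch.randn(4, 1, 1000)                       # conditional (2) + unconditional (2) logits stacked on the batch dim
cond, uncond = torch.split(logits, logits.shape[0] // 2, dim=0)
cfg_logits = uncond + (cond - uncond) * 5.0            # same mixing rule as in the generation loop above
next_token, probs = sample(cfg_logits, temperature=0.9, top_k=50, top_p=0.96)
print(next_token.shape)                                # (2, 1): one sampled token id per conditional prompt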
tools.py ADDED
@@ -0,0 +1,126 @@
1
+ import datetime
2
+ import logging
3
+ import logging.handlers
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ from constants import LOGDIR
10
+
11
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13
+
14
+ handler = None
15
+
16
+
17
+ def build_logger(logger_name, logger_filename):
18
+ global handler
19
+
20
+ formatter = logging.Formatter(
21
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22
+ datefmt="%Y-%m-%d %H:%M:%S",
23
+ )
24
+
25
+ # Set the format of root handlers
26
+ if not logging.getLogger().handlers:
27
+ logging.basicConfig(level=logging.INFO)
28
+ logging.getLogger().handlers[0].setFormatter(formatter)
29
+
30
+ # Redirect stdout and stderr to loggers
31
+ stdout_logger = logging.getLogger("stdout")
32
+ stdout_logger.setLevel(logging.INFO)
33
+ sl = StreamToLogger(stdout_logger, logging.INFO)
34
+ sys.stdout = sl
35
+
36
+ stderr_logger = logging.getLogger("stderr")
37
+ stderr_logger.setLevel(logging.ERROR)
38
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
39
+ sys.stderr = sl
40
+
41
+ # Get logger
42
+ logger = logging.getLogger(logger_name)
43
+ logger.setLevel(logging.INFO)
44
+
45
+ # Add a file handler for all loggers
46
+ if handler is None:
47
+ os.makedirs(LOGDIR, exist_ok=True)
48
+ filename = os.path.join(LOGDIR, logger_filename)
49
+ handler = logging.handlers.TimedRotatingFileHandler(
50
+ filename, when='D', utc=True, encoding='UTF-8')
51
+ handler.setFormatter(formatter)
52
+
53
+ for name, item in logging.root.manager.loggerDict.items():
54
+ if isinstance(item, logging.Logger):
55
+ item.addHandler(handler)
56
+
57
+ return logger
58
+
59
+
60
+ class StreamToLogger(object):
61
+ """
62
+ Fake file-like stream object that redirects writes to a logger instance.
63
+ """
64
+ def __init__(self, logger, log_level=logging.INFO):
65
+ self.terminal = sys.stdout
66
+ self.logger = logger
67
+ self.log_level = log_level
68
+ self.linebuf = ''
69
+
70
+ def __getattr__(self, attr):
71
+ return getattr(self.terminal, attr)
72
+
73
+ def write(self, buf):
74
+ temp_linebuf = self.linebuf + buf
75
+ self.linebuf = ''
76
+ for line in temp_linebuf.splitlines(True):
77
+ # From the io.TextIOWrapper docs:
78
+ # On output, if newline is None, any '\n' characters written
79
+ # are translated to the system default line separator.
80
+ # By default sys.stdout.write() expects '\n' newlines and then
81
+ # translates them so this is still cross platform.
82
+ if line[-1] == '\n':
83
+ self.logger.log(self.log_level, line.rstrip())
84
+ else:
85
+ self.linebuf += line
86
+
87
+ def flush(self):
88
+ if self.linebuf != '':
89
+ self.logger.log(self.log_level, self.linebuf.rstrip())
90
+ self.linebuf = ''
91
+
92
+
93
+ def disable_torch_init():
94
+ """
95
+ Disable the redundant torch default initialization to accelerate model creation.
96
+ """
97
+ import torch
98
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100
+
101
+
102
+ def violates_moderation(text):
103
+ """
104
+ Check whether the text violates OpenAI moderation API.
105
+ """
106
+ url = "https://api.openai.com/v1/moderations"
107
+ headers = {"Content-Type": "application/json",
108
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109
+ text = text.replace("\n", "")
110
+ data = "{" + '"input": ' + f'"{text}"' + "}"
111
+ data = data.encode("utf-8")
112
+ try:
113
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
114
+ flagged = ret.json()["results"][0]["flagged"]
115
+ except requests.exceptions.RequestException as e:
116
+ flagged = False
117
+ except KeyError as e:
118
+ flagged = False
119
+
120
+ return flagged
121
+
122
+
123
+ def pretty_print_semaphore(semaphore):
124
+ if semaphore is None:
125
+ return "None"
126
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
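A hypothetical usage of the helpers above; the logger name and log filename are placeholders, and the rotating log file is written under constants.LOGDIR:

from tools import build_logger, disable_torch_init

disable_torch_init()                        # skip torch's default weight init when a checkpoint is loaded anyway
logger = build_logger("demo", "demo.log")   # also redirects stdout/stderr into the rotating log file
logger.info("model loaded")
print("this line is captured by the 'stdout' logger as well")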
unitok/config.py ADDED
@@ -0,0 +1,243 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+ from tap import Tap
7
+ from typing import Optional, Union
8
+ from collections import OrderedDict
9
+
10
+ from unitok import dist
11
+
12
+
13
+ class Args(Tap):
14
+ model: str = 'vitamin_large' # 'vitamin_base', 'vitamin_large', xxx
15
+ exp_name: str = 'unitok_large'
16
+ output_dir: str = 'local_output'
17
+ resume_from: str = '' # if specified, load this checkpoint; if not, load the latest checkpoint in output_dir (if exists)
18
+ lpips_path: str = 'external/lpips_with_vgg.pth'
19
+ dino_path: str = 'external/dinov2_vits14_pretrain.pth'
20
+ fid_eval_src: str = ''
21
+ fid_eval_dst: str = ''
22
+ vis_img_dir: str = 'asset/vis_imgs/'
23
+ fid_feature_extractor: str = 'external/weights-inception-2015-12-05-6726825d.pth'
24
+ clip_pretrain_path: str = ''
25
+
26
+ # speed-up
27
+ fp16: bool = False # whether to use FP16
28
+ bf16: bool = True # whether to use BF16
29
+ tf32: bool = True # whether to use TensorFloat32
30
+ compile_model: bool = False # whether to use torch.compile()
31
+ ddp_static: bool = False # whether to use static graph in DDP
32
+ grad_ckpt: bool = True # gradient checkpointing
33
+ grad_accu: int = 1 # gradient accumulation
34
+ device: str = 'cpu' # will be set automatically
35
+ dtype: torch.dtype = torch.float32 # will be set automatically
36
+
37
+ # data
38
+ train_data: str = None
39
+ val_data: str = None
40
+ dataset_type: str = 'webdataset'
41
+ imagenet_val: str = None
42
+ imagenet_v2: str = None
43
+ subset_ratio: float = 1.0
44
+ img_size: int = 256
45
+ resize_ratio: float = 1.125 # only applicable to 'img' dataset_type
46
+ hflip: bool = False
47
+ workers: int = 8 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
48
+ train_num_samples: int = 1280_000_000
49
+ train_data_upsampling_factors: str = None
50
+ dataset_resampled: bool = False
51
+ use_aug: bool = False
52
+
53
+ # quantizer
54
+ vocab_size: int = 32768
55
+ vocab_width: int = 64
56
+ vocab_norm: bool = True
57
+ vq_beta: float = 0.25 # commitment loss weight
58
+ num_codebooks: int = 8
59
+ quant_proj: str = 'attn'
60
+
61
+ # model
62
+ embed_dim: int = 768
63
+ num_query: int = 0
64
+ use_clip_pretrain: bool = False
65
+ patch_size: int = 16
66
+ drop_path: float = 0.1
67
+ text_width: int = 768
68
+ text_heads: int = 12
69
+ text_layers: int = 12
70
+ text_vocab_size: int = 49408
71
+ text_context_length: int = 77
72
+
73
+ # CLIP
74
+ local_loss: bool = True
75
+ gather_with_grad: bool = True
76
+ pretrained_clip: str = None
77
+ pretrained_clip_text: str = None
78
+ lock_text: bool = False
79
+ lock_text_unlocked_layers: int = 0
80
+ lock_text_freeze_layer_norm: bool = False
81
+ force_custom_text: bool = False
82
+ force_custom_vision: bool = False
83
+ zeroshot_eval_freq: int = 1
84
+
85
+ # discriminator
86
+ dino_depth: int = 12
87
+ dino_kernel_size: int = 9
88
+ disc_norm: str = 'gn' # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm
89
+ disc_aug_prob: float = 1.0
90
+ disc_specnorm: bool = False
91
+ step_disc_every: int = 1
92
+
93
+ # initialization
94
+ vae_init: float = -0.5 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
95
+ vocab_init: float = -1 # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init)
96
+ disc_init: float = -0.5 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init)
97
+
98
+ # optimization
99
+ epoch: int = 1 # number of epochs
100
+ local_bs: int = 64 # batch size per device; if this is specified, --global_bs will be ignored
101
+ vae_local_bs: int = 64 # sub-batch size for vae loss calculation
102
+ global_bs: int = 0 # global batch size (exclusive to --local_bs)
103
+ lr: float = 5e-4 # learning rate
104
+ wd: float = 0.02 # weight decay
105
+ disc_lr: float = 2e-5 # disc lr
106
+ disc_wd: float = 0.2
107
+ grad_clip: float = 10 # <=0 for not using grad clip
108
+ ema: float = 0.9999 # ema ratio
109
+ warmup_iter: int = None
110
+ warmup_ep: float = 0.01 # lr warmup: epochs
111
+ disc_start_ep: float = 0.375 # start using disc loss for VAE after xxx epochs;
112
+ disc_warmup_ep: float = 0.03 # disc loss warm up epochs;
113
+ schedule: str = 'cos' # lr schedule type
114
+ lr_start_ratio: float = 0. # lr warmup: initial lr ratio
115
+ lr_end_ratio: float = 0.1 # lr schedule: final lr ratio
116
+ disc_lr_end_ratio: float = 0.1
117
+ custom_lr_multiplier: float = None
118
+ optimizer: str = 'adamw'
119
+ optim_eps: float = 1e-6
120
+ fuse_opt: bool = False # whether to use fused optimizer
121
+ optim_beta: str = '0.9_0.95' # beta1, beta2 of optimizer
122
+ disc_optim_beta: str = '0.5_0.9' # beta1, beta2 of disc optimizer
123
+
124
+ # loss
125
+ l1: float = 0.2 # L1 rec loss weight
126
+ l2: float = 1.0 # L2 rec loss weight
127
+ lp: float = 1.0 # lpips loss weight
128
+ lpr: int = 48 # only calculate lpips >= this image resolution
129
+ ld: float = 0.4 # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT
130
+ le: float = 0.0 # VQ entropy loss weight
131
+ lq: float = 1.0
132
+ lc: float = 1.0 # CLIP loss weight
133
+ e_temp: float = 0.01
134
+ gada: int = 1
135
+ bcr: float = 4. # balanced Consistency Regularization, used on small dataset with low reso, StyleSwin: 10.0
136
+ bcr_cut: float = 0.2 # cutout ratio (0.5: 50% width)
137
+ dcrit: str = 'hg' # hg hinge, sp softplus, ln linear
138
+
139
+ # wandb log
140
+ report_wandb: bool = True
141
+ wandb_notes: str = None
142
+ run_id: str = None
143
+
144
+ # debug
145
+ eval_per_epoch: int = 8
146
+ dbg_unused_param: bool = False
147
+ dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
148
+ seed: int = None
149
+ deterministic: bool = False
150
+ same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
151
+
152
+
153
+ def seed_everything(self):
154
+ torch.backends.cudnn.enabled = True
155
+ torch.backends.cudnn.benchmark = True
156
+ torch.backends.cudnn.deterministic = False
157
+ if self.seed is not None:
158
+ if self.deterministic:
159
+ torch.backends.cudnn.benchmark = False
160
+ torch.backends.cudnn.deterministic = True
161
+ torch.use_deterministic_algorithms(True)
162
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'
163
+ seed = self.seed + dist.get_rank() * 10000
164
+ os.environ['PYTHONHASHSEED'] = str(seed)
165
+ random.seed(seed)
166
+ np.random.seed(seed)
167
+ torch.manual_seed(seed)
168
+ torch.cuda.manual_seed(seed)
169
+ torch.cuda.manual_seed_all(seed)
170
+
171
+ def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
172
+ if self.seed is None:
173
+ return None
174
+ g = torch.Generator()
175
+ g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
176
+ return g
177
+
178
+ def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
179
+ d = (OrderedDict if key_ordered else dict)()
180
+ for k in self.class_variables.keys():
181
+ if k not in {'device'}: # these are not serializable
182
+ d[k] = getattr(self, k)
183
+ return d
184
+
185
+ def load_state_dict(self, state_dict):
186
+ for k, v in state_dict.items():
187
+ try:
188
+ setattr(self, k, v)
189
+ except Exception as e:
190
+ print(f'k={k}, v={v}')
191
+ raise e
192
+
193
+ @staticmethod
194
+ def set_tf32(tf32: bool):
195
+ if torch.cuda.is_available():
196
+ torch.backends.cudnn.allow_tf32 = bool(tf32)
197
+ torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
198
+ if hasattr(torch, 'set_float32_matmul_precision'):
199
+ torch.set_float32_matmul_precision('high' if tf32 else 'highest')
200
+ print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
201
+ print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
202
+ print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
203
+
204
+ def __str__(self):
205
+ s = []
206
+ for k in self.class_variables.keys():
207
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
208
+ s.append(f' {k:20s}: {getattr(self, k)}')
209
+ s = '\n'.join(s)
210
+ return f'{{\n{s}\n}}\n'
211
+
212
+
213
+ def init_dist_and_get_args():
214
+ for i in range(len(sys.argv)):
215
+ if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
216
+ del sys.argv[i]
217
+ break
218
+
219
+ args = Args(explicit_bool=True).parse_args(known_only=True)
220
+ # warn args.extra_args
221
+ if len(args.extra_args) > 0:
222
+ print(f'======================================================================================')
223
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
224
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
225
+ print(f'======================================================================================\n\n')
226
+
227
+ # init torch distributed
228
+ os.makedirs(args.output_dir, exist_ok=True)
229
+ dist.init_distributed_mode(local_out_path=args.output_dir, timeout_minutes=30)
230
+
231
+ # set env
232
+ args.set_tf32(args.tf32)
233
+ args.seed_everything()
234
+ args.device = dist.get_device()
235
+
236
+ # update args
237
+ if args.local_bs == 0:
238
+ args.local_bs = max(1, round(args.global_bs / args.grad_accu / dist.get_world_size()))
239
+ args.global_bs = args.local_bs * dist.get_world_size()
240
+ if args.fp16 or args.bf16:
241
+ args.dtype = torch.float16 if args.fp16 else torch.bfloat16
242
+
243
+ return args
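Restoring a serialized Args from a UniTok checkpoint follows the same pattern t2i.py uses; the checkpoint path below is a placeholder:

import torch
from unitok.config import Args

ckpt = torch.load('unitok_checkpoint.pth', map_location='cpu')
vae_cfg = Args()
vae_cfg.load_state_dict(ckpt['args'])        # replay the saved class variables onto a fresh Args
print(vae_cfg.num_codebooks, vae_cfg.vocab_size, vae_cfg.img_size)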
unitok/dist.py ADDED
@@ -0,0 +1,302 @@
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import pytz
9
+ import torch
10
+ import torch.distributed as tdist
11
+ import torch.multiprocessing as mp
12
+
13
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ __rank_str_zfill = '0'
15
+ __initialized = False
16
+
17
+
18
+ def initialized():
19
+ return __initialized
20
+
21
+
22
+ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
23
+ global __device
24
+ if not torch.cuda.is_available():
25
+ print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
26
+ return
27
+ elif 'RANK' not in os.environ:
28
+ torch.cuda.set_device(gpu_id_if_not_distibuted)
29
+ __device = torch.empty(1).cuda().device
30
+ print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
31
+ return
32
+ # then 'RANK' must exist
33
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
34
+ local_rank = global_rank % num_gpus
35
+ torch.cuda.set_device(local_rank)
36
+
37
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
38
+ if mp.get_start_method(allow_none=True) is None:
39
+ method = 'fork' if fork else 'spawn'
40
+ print(f'[dist initialize] mp method={method}')
41
+ mp.set_start_method(method)
42
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
43
+
44
+ global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
45
+ __local_rank = local_rank
46
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
47
+ __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
48
+ __device = torch.empty(1).cuda().device
49
+ __initialized = True
50
+
51
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
52
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
53
+
54
+
55
+ def get_rank():
56
+ return __rank
57
+
58
+
59
+ def get_rank_str_zfill():
60
+ return __rank_str_zfill
61
+
62
+
63
+ def get_local_rank():
64
+ return __local_rank
65
+
66
+
67
+ def get_world_size():
68
+ return __world_size
69
+
70
+
71
+ def get_device():
72
+ return __device
73
+
74
+
75
+ def set_gpu_id(gpu_id: int):
76
+ if gpu_id is None: return
77
+ global __device
78
+ if isinstance(gpu_id, (str, int)):
79
+ torch.cuda.set_device(int(gpu_id))
80
+ __device = torch.empty(1).cuda().device
81
+ else:
82
+ raise NotImplementedError
83
+
84
+
85
+ def is_master():
86
+ return __rank == 0
87
+
88
+
89
+ def is_local_master():
90
+ return __local_rank == 0
91
+
92
+
93
+ def new_group(ranks: List[int]):
94
+ if __initialized:
95
+ return tdist.new_group(ranks=ranks)
96
+ return None
97
+
98
+
99
+ def new_local_machine_group():
100
+ if __initialized:
101
+ cur_subgroup, subgroups = tdist.new_subgroups()
102
+ return cur_subgroup
103
+ return None
104
+
105
+
106
+ def barrier():
107
+ if __initialized:
108
+ tdist.barrier()
109
+
110
+
111
+ def allreduce(t: torch.Tensor, async_op=False):
112
+ if __initialized:
113
+ if not t.is_cuda:
114
+ cu = t.detach().cuda()
115
+ ret = tdist.all_reduce(cu, async_op=async_op)
116
+ t.copy_(cu.cpu())
117
+ else:
118
+ ret = tdist.all_reduce(t, async_op=async_op)
119
+ return ret
120
+ return None
121
+
122
+
123
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
124
+ if __initialized:
125
+ if not t.is_cuda:
126
+ t = t.cuda()
127
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
128
+ tdist.all_gather(ls, t)
129
+ else:
130
+ ls = [t]
131
+ if cat:
132
+ ls = torch.cat(ls, dim=0)
133
+ return ls
134
+
135
+
136
+ def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
137
+ if __initialized:
138
+ if not t.is_cuda:
139
+ t = t.cuda()
140
+
141
+ t_size = torch.tensor(t.size(), device=t.device)
142
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
143
+ tdist.all_gather(ls_size, t_size)
144
+
145
+ max_B = max(size[0].item() for size in ls_size)
146
+ pad = max_B - t_size[0].item()
147
+ if pad:
148
+ pad_size = (pad, *t.size()[1:])
149
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
150
+
151
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
152
+ tdist.all_gather(ls_padded, t)
153
+ ls = []
154
+ for t, size in zip(ls_padded, ls_size):
155
+ ls.append(t[:size[0].item()])
156
+ else:
157
+ ls = [t]
158
+ if cat:
159
+ ls = torch.cat(ls, dim=0)
160
+ return ls
161
+
162
+
163
+ def broadcast(t: torch.Tensor, src_rank) -> None:
164
+ if __initialized:
165
+ if not t.is_cuda:
166
+ cu = t.detach().cuda()
167
+ tdist.broadcast(cu, src=src_rank)
168
+ t.copy_(cu.cpu())
169
+ else:
170
+ tdist.broadcast(t, src=src_rank)
171
+
172
+
173
+ def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
174
+ if not initialized():
175
+ return torch.tensor([val]) if fmt is None else [fmt % val]
176
+
177
+ ts = torch.zeros(__world_size)
178
+ ts[__rank] = val
179
+ allreduce(ts)
180
+ if fmt is None:
181
+ return ts
182
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
183
+
184
+
185
+ def master_only(func):
186
+ @functools.wraps(func)
187
+ def wrapper(*args, **kwargs):
188
+ force = kwargs.pop('force', False)
189
+ if force or is_master():
190
+ ret = func(*args, **kwargs)
191
+ else:
192
+ ret = None
193
+ barrier()
194
+ return ret
195
+ return wrapper
196
+
197
+
198
+ def local_master_only(func):
199
+ @functools.wraps(func)
200
+ def wrapper(*args, **kwargs):
201
+ force = kwargs.pop('force', False)
202
+ if force or is_local_master():
203
+ ret = func(*args, **kwargs)
204
+ else:
205
+ ret = None
206
+ barrier()
207
+ return ret
208
+ return wrapper
209
+
210
+
211
+ def for_visualize(func):
212
+ @functools.wraps(func)
213
+ def wrapper(*args, **kwargs):
214
+ if is_master():
215
+ # with torch.no_grad():
216
+ ret = func(*args, **kwargs)
217
+ else:
218
+ ret = None
219
+ return ret
220
+ return wrapper
221
+
222
+
223
+ def finalize():
224
+ if __initialized:
225
+ tdist.destroy_process_group()
226
+
227
+
228
+ def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
229
+ try:
230
+ __initialize(fork=False, timeout_minutes=timeout_minutes)
231
+ barrier()
232
+ except RuntimeError as e:
233
+ print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
234
+ raise e
235
+
236
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
237
+ _change_builtin_print(is_local_master())
238
+ if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
239
+ sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)
240
+
241
+
242
+ def _change_builtin_print(is_master):
243
+ import builtins as __builtin__
244
+
245
+ builtin_print = __builtin__.print
246
+ if type(builtin_print) != type(open):
247
+ return
248
+
249
+ def prt(*args, **kwargs):
250
+ force = kwargs.pop('force', False)
251
+ clean = kwargs.pop('clean', False)
252
+ deeper = kwargs.pop('deeper', False)
253
+ if is_master or force:
254
+ if not clean:
255
+ f_back = sys._getframe().f_back
256
+ if deeper and f_back.f_back is not None:
257
+ f_back = f_back.f_back
258
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
259
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
260
+ builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
261
+ else:
262
+ builtin_print(*args, **kwargs)
263
+
264
+ __builtin__.print = prt
265
+
266
+
267
+ class BackupStreamToFile(object):
268
+ def __init__(self, local_output_dir, for_stdout=True):
269
+ self.for_stdout = for_stdout
270
+ self.terminal_stream = sys.stdout if for_stdout else sys.stderr
271
+ fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
272
+ existing = os.path.exists(fname)
273
+ self.file_stream = open(fname, 'a')
274
+ if existing:
275
+ time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
276
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
277
+ self.file_stream.flush()
278
+ self.enabled = True
279
+
280
+ def write(self, message):
281
+ self.terminal_stream.write(message)
282
+ self.file_stream.write(message)
283
+
284
+ def flush(self):
285
+ self.terminal_stream.flush()
286
+ self.file_stream.flush()
287
+
288
+ def close(self):
289
+ if not self.enabled:
290
+ return
291
+ self.enabled = False
292
+ self.file_stream.flush()
293
+ self.file_stream.close()
294
+ if self.for_stdout:
295
+ sys.stdout = self.terminal_stream
296
+ sys.stdout.flush()
297
+ else:
298
+ sys.stderr = self.terminal_stream
299
+ sys.stderr.flush()
300
+
301
+ def __del__(self):
302
+ self.close()
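These helpers degrade gracefully when torch.distributed is never initialized, which makes single-GPU debugging easy; a single-process sketch:

from unitok import dist

@dist.master_only
def save_checkpoint(path):
    print(f'saving to {path}')           # runs on rank 0 only; other ranks just hit the barrier and get None

save_checkpoint('ckpt.pth')              # without distributed init the rank is 0, so this prints
print(dist.dist_fmt_vals(0.1234))        # ['0.12'] -- one formatted value per rank (a single rank here)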
unitok/model.py ADDED
@@ -0,0 +1,184 @@
1
+ import timm
2
+ import torch
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from contextlib import nullcontext
7
+
8
+ from unitok.vitamin import GeGluMlp, ViTaminDecoder
9
+ from unitok.quant import VectorQuantizerM
10
+ from unitok.vqvae import AttnProjection
11
+
12
+
13
+ class UniTok(nn.Module):
14
+ def __init__(self, args):
15
+ super().__init__()
16
+
17
+ self.num_query = args.num_query
18
+
19
+ self.encoder = timm.create_model(
20
+ args.model,
21
+ patch_size=1,
22
+ fc_norm=False,
23
+ drop_rate=0.0,
24
+ num_classes=0,
25
+ global_pool='',
26
+ pos_embed='none',
27
+ class_token=False,
28
+ mlp_layer=GeGluMlp,
29
+ reg_tokens=args.num_query,
30
+ img_size=args.img_size,
31
+ drop_path_rate=args.drop_path,
32
+ )
33
+ self.encoder.pos_embed = nn.Parameter(torch.zeros(1, 1, self.encoder.embed_dim), requires_grad=False)
34
+
35
+ if args.quant_proj == 'linear':
36
+ self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
37
+ elif args.quant_proj == 'attn':
38
+ self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, self.encoder.embed_dim // args.vocab_width)
39
+ else:
40
+ raise NotImplementedError
41
+
42
+ self.quantizer = VectorQuantizerM(
43
+ vocab_size=args.vocab_size,
44
+ vocab_width=args.vocab_width,
45
+ beta=args.vq_beta,
46
+ use_entropy_loss=args.le > 0,
47
+ entropy_temp=args.e_temp,
48
+ num_codebooks=args.num_codebooks,
49
+ )
50
+
51
+ if args.quant_proj == 'linear':
52
+ self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
53
+ elif args.quant_proj == 'attn':
54
+ self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, self.encoder.embed_dim // args.vocab_width)
55
+ else:
56
+ raise NotImplementedError
57
+
58
+ self.decoder = ViTaminDecoder(
59
+ args.model,
60
+ num_query=args.num_query,
61
+ img_size=args.img_size,
62
+ drop_path=args.drop_path,
63
+ grad_ckpt=args.grad_ckpt,
64
+ )
65
+
66
+ text_cfg = {
67
+ "width": args.text_width,
68
+ "heads": args.text_heads,
69
+ "layers": args.text_layers,
70
+ "vocab_size": args.text_vocab_size,
71
+ "context_length": args.text_context_length,
72
+ }
73
+ from open_clip.model import _build_text_tower
74
+ self.text_encoder = _build_text_tower(args.embed_dim, text_cfg)
75
+
76
+ self.fc_norm = nn.LayerNorm(self.encoder.embed_dim, eps=1e-6)
77
+ self.projection = nn.Linear(self.encoder.embed_dim, args.embed_dim)
78
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
79
+
80
+ self.context_length = self.text_encoder.context_length
81
+ self.vocab_size = self.text_encoder.vocab_size
82
+ self.maybe_record_function = nullcontext
83
+
84
+ self.text_no_grad = False
85
+ self.encoder.set_grad_checkpointing(args.grad_ckpt)
86
+ self.text_encoder.set_grad_checkpointing(args.grad_ckpt)
87
+
88
+ def forward(self, img, vae_bs, text=None, ret_usages=False):
89
+ img_tokens = self.encoder(img).float()
90
+ with torch.cuda.amp.autocast(enabled=False):
91
+ img_tokens = torch.utils.checkpoint.checkpoint(self.quant_proj, img_tokens, use_reentrant=False)
92
+ img_tokens, vq_loss, entropy_loss, usages = self.quantizer(img_tokens)
93
+ img_tokens = torch.utils.checkpoint.checkpoint(self.post_quant_proj, img_tokens, use_reentrant=False)
94
+ img_rec = self.decoder(img_tokens[:vae_bs]).float()
95
+
96
+ clip_visual = img_tokens.mean(dim=1)
97
+ clip_visual = self.projection(self.fc_norm(clip_visual))
98
+ clip_visual = F.normalize(clip_visual, dim=-1)
99
+ if text is not None:
100
+ clip_text = self.text_encoder(text)
101
+ clip_text = F.normalize(clip_text, dim=-1)
102
+ else:
103
+ clip_text = None
104
+
105
+ output_dict = {
106
+ "img_rec": img_rec,
107
+ "vq_loss": vq_loss,
108
+ "entropy_loss": entropy_loss,
109
+ "codebook_usages": usages,
110
+ "clip_image_features": clip_visual,
111
+ "clip_text_features": clip_text,
112
+ "logit_scale": self.logit_scale.exp()
113
+ }
114
+ return output_dict
115
+
116
+ def encode_image(self, image, normalize: bool = False):
117
+ img_tokens = self.encoder(image)
118
+ img_tokens = self.quant_proj(img_tokens)
119
+ img_indices = self.quantizer.f_to_idx(img_tokens)
120
+ img_tokens = self.quantizer.idx_to_f(img_indices)
121
+ img_tokens = self.post_quant_proj(img_tokens)
122
+ features = img_tokens.mean(dim=1)
123
+ features = self.projection(self.fc_norm(features))
124
+ return F.normalize(features, dim=-1) if normalize else features
125
+
126
+ def encode_text(self, text, normalize: bool = False):
127
+ features = self.text_encoder(text)
128
+ return F.normalize(features, dim=-1) if normalize else features
129
+
130
+ def img_to_idx(self, img):
131
+ features = self.encoder(img).float()
132
+ features = self.quant_proj(features)
133
+ return self.quantizer.f_to_idx(features)
134
+
135
+ def idx_to_img(self, indices):
136
+ features = self.quantizer.idx_to_f(indices)
137
+ features = self.post_quant_proj(features)
138
+ img = self.decoder(features).clamp_(-1, 1)
139
+ return img
140
+
141
+ def img_to_reconstructed_img(self, image) -> torch.Tensor:
142
+ img_tokens = self.encoder(image)
143
+ img_tokens = self.quant_proj(img_tokens)
144
+ img_tokens, _, _, _ = self.quantizer(img_tokens)
145
+ img_tokens = self.post_quant_proj(img_tokens)
146
+ img_rec = self.decoder(img_tokens).clamp_(-1, 1)
147
+ return img_rec
148
+
149
+ def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True, unlock_text_proj=False):
150
+ self.text.lock(unlocked_layers, freeze_layer_norm, unlock_text_proj)
151
+ self.text_no_grad = True
152
+
153
+
154
+ if __name__ == '__main__':
155
+ model = timm.create_model(
156
+ 'vitamin_base',
157
+ patch_size=1,
158
+ fc_norm=True,
159
+ drop_rate=0.0,
160
+ num_classes=0,
161
+ global_pool='',
162
+ pos_embed='none',
163
+ class_token=False,
164
+ mlp_layer=GeGluMlp,
165
+ reg_tokens=0,
166
+ img_size=256,
167
+ drop_path_rate=0.1,
168
+ )
169
+ model.pos_embed = nn.Parameter(torch.zeros(1, 1, model.embed_dim), requires_grad=False)
170
+
171
+ model_dict = model.state_dict()
172
+ ckpt_dict = torch.load('ViTamin-B/pytorch_model.bin')
173
+ visual_dict = dict()
174
+ for k, v in ckpt_dict.items():
175
+ if k.startswith('visual.'):
176
+ if 'head' in k or 'pos_embed' in k:
177
+ continue
178
+ new_k = k.replace('visual.trunk.', '')
179
+ visual_dict[new_k] = v
180
+
181
+ model.load_state_dict(visual_dict, strict=False)
182
+ print(set(model_dict.keys()) - set(visual_dict.keys()))
183
+ print(set(visual_dict.keys() - set(model_dict.keys())))
184
+
unitok/quant.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ from typing import List, Tuple
3
+ from torch.nn import functional as F
4
+ from torch import distributed as tdist, nn as nn
5
+
6
+ from unitok import dist
7
+
8
+
9
+ def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
10
+ E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
11
+ E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1) # E_dist: (N, vocab_size)
12
+ logits = -E_dist.float().mul_(inv_entropy_tau)
13
+ # calc per_sample_entropy
14
+ prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1) # both are (N, vocab_size)
15
+ per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
16
+ # calc codebook_entropy
17
+ avg_prob = prob.mean(dim=0) # (vocab_size,)
18
+ log_avg_prob = torch.log(avg_prob + 1e-7)
19
+ codebook_entropy = (-avg_prob * log_avg_prob).sum()
20
+ # calc entropy_loss
21
+ entropy_loss = per_sample_entropy - codebook_entropy
22
+ return entropy_loss
23
+
24
+
25
+ class NormalizedEmbedding(nn.Embedding):
26
+ def __init__(self, num_embeddings: int, embedding_dim: int):
27
+ super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
28
+ # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))
29
+
30
+ def forward(self, idx):
31
+ return F.embedding(
32
+ idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
33
+ self.norm_type, self.scale_grad_by_freq, self.sparse
34
+ )
35
+
36
+ def get_norm_weight(self):
37
+ return F.normalize(self.weight, dim=1)
38
+
39
+
40
+ class ResConv(nn.Conv2d):
41
+ def __init__(self, embed_dim, quant_resi):
42
+ ks = 3 if quant_resi < 0 else 1
43
+ super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
44
+ self.resi_ratio = abs(quant_resi)
45
+
46
+ def forward(self, h_BChw):
47
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
48
+
49
+
50
+ class VectorQuantizer(nn.Module):
51
+ def __init__(
52
+ self,
53
+ vocab_size: int,
54
+ vocab_width: int,
55
+ beta: float = 0.25,
56
+ use_entropy_loss=False,
57
+ entropy_temp=0.01,
58
+ ):
59
+ super().__init__()
60
+ self.beta = beta
61
+ self.vocab_size = vocab_size
62
+ self.vocab_width = vocab_width
63
+ self.vocab_usage_record_times: int = 0
64
+ self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
65
+ self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)
66
+
67
+ self.use_entropy_loss = use_entropy_loss
68
+ self.inv_entropy_tau = 1 / entropy_temp
69
+
70
+ def init_vocab(self, eini: float):
71
+ if eini > 0:
72
+ nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
73
+ elif eini < 0:
74
+ base = self.vocab_width ** -0.5
75
+ base /= 36
76
+ self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)
77
+
78
+ def extra_repr(self) -> str:
79
+ return f'beta={self.beta:g}'
80
+
81
+ def forward(self, features):
82
+ B, L, C = features.shape
83
+ features = features.reshape(-1, C)
84
+ features = F.normalize(features, dim=-1).float()
85
+ codebook_embed = self.codebook.get_norm_weight()
86
+ indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
87
+ entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
88
+ features_hat = self.codebook(indices)
89
+
90
+ # calc loss
91
+ vq_loss = F.mse_loss(features_hat.detach(), features).mul_(self.beta) + F.mse_loss(features_hat,
92
+ features.detach())
93
+ features_hat = (features_hat.detach() - features.detach()).add_(features)
94
+
95
+ # update vocab_usage
96
+ prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
97
+ handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
98
+ self.training and dist.initialized()) else None
99
+ if handler is not None:
100
+ handler.wait()
101
+ prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
102
+ vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
103
+ if self.vocab_usage_record_times == 0:
104
+ self.vocab_usage.copy_(prob_per_class_is_chosen)
105
+ elif self.vocab_usage_record_times < 100:
106
+ self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
107
+ else:
108
+ self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
109
+ self.vocab_usage_record_times += 1
110
+
111
+ return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage
112
+
113
+ def f_to_idx(self, features):
114
+ B, L, C = features.shape
115
+ features = features.reshape(-1, C)
116
+ features = F.normalize(features, dim=-1).float()
117
+ codebook_embed = self.codebook.get_norm_weight().float()
118
+ indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
119
+ return indices.view(B, L)
120
+
121
+
122
+ class VectorQuantizerM(nn.Module):
123
+ def __init__(
124
+ self,
125
+ vocab_size,
126
+ vocab_width,
127
+ beta=0.25,
128
+ use_entropy_loss=False,
129
+ entropy_temp=0.01,
130
+ num_codebooks=16
131
+ ):
132
+ super().__init__()
133
+ self.num_codebooks = num_codebooks
134
+ self.codebooks = nn.ModuleList()
135
+ for _ in range(num_codebooks):
136
+ codebook = VectorQuantizer(
137
+ vocab_size=vocab_size // num_codebooks,
138
+ vocab_width=vocab_width // num_codebooks,
139
+ beta=beta,
140
+ use_entropy_loss=use_entropy_loss,
141
+ entropy_temp=entropy_temp,
142
+ )
143
+ self.codebooks.append(codebook)
144
+
145
+ def init_vocab(self, eini: float):
146
+ for codebook in self.codebooks:
147
+ codebook.init_vocab(eini)
148
+
149
+ def f_to_idx(self, features):
150
+ indices = []
151
+ chunk_size = features.shape[-1] // self.num_codebooks
152
+ splited_features = features.split(chunk_size, dim=-1)
153
+ for i, codebook in enumerate(self.codebooks):
154
+ indices.append(codebook.f_to_idx(splited_features[i]))
155
+ indices = torch.stack(indices, dim=1)
156
+ return indices
157
+
158
+ def idx_to_f(self, indices):
159
+ assert indices.shape[1] == self.num_codebooks
160
+ latent_features = []
161
+ for i, codebook in enumerate(self.codebooks):
162
+ sub_indices = indices[:, i].flatten(start_dim=1)
163
+ latent_feature = codebook.codebook(sub_indices)
164
+ latent_features.append(latent_feature)
165
+ latent_features = torch.cat(latent_features, dim=-1)
166
+ return latent_features
167
+
168
+ def forward(self, features):
169
+ latent_features = []
170
+ global_vq_loss = 0.
171
+ global_entropy_loss = 0.
172
+ global_vocab_usage = 0.
173
+ chunk_size = features.shape[-1] // self.num_codebooks
174
+ splited_features = features.split(chunk_size, dim=-1)
175
+ for i, codebook in enumerate(self.codebooks):
176
+ latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(splited_features[i])
177
+ latent_features.append(latent_feature)
178
+ global_vq_loss += vq_loss
179
+ global_entropy_loss += entropy_loss
180
+ global_vocab_usage += vocab_usage
181
+ latent_features = torch.cat(latent_features, dim=-1)
182
+ global_entropy_loss /= self.num_codebooks
183
+ global_vq_loss /= self.num_codebooks
184
+ global_vocab_usage /= self.num_codebooks
185
+ return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
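A shape sketch for the multi-codebook quantizer above; the sizes match the Args defaults (a 32768-entry vocabulary split across 8 codebooks, each holding 4096 entries of width 8), while the batch and token counts are arbitrary:

import torch
from unitok.quant import VectorQuantizerM

quant = VectorQuantizerM(vocab_size=32768, vocab_width=64, num_codebooks=8)
feats = torch.randn(2, 256, 64)                 # (batch, tokens, vocab_width)
out, vq_loss, ent_loss, usage = quant(feats)    # each 8-dim chunk is matched against its own 4096-entry codebook
idx = quant.f_to_idx(feats)                     # (2, 8, 256): one index row per codebook
rec = quant.idx_to_f(idx)                       # (2, 256, 64): concatenated codebook embeddings
print(out.shape, idx.shape, rec.shape)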
unitok/vitamin.py ADDED
@@ -0,0 +1,792 @@
1
+ """
2
+ TODO: FIXME:
3
+ /usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py:251: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
4
+ grad.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
5
+ bucket_view.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:334.) Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
6
+ /usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py:251: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
7
+ grad.sizes() = [256, 1024, 1, 1], strides() = [1024, 1, 1024, 1024]
8
+
9
+ """
10
+
11
+ """ ViTamin
12
+
13
+ Paper: Designing Scalable Vision Models in the Vision-Language Era
14
+
15
+ @misc{chen2023designing,
16
+ title={Designing Scalable Vision Models in the Vision-Language Era},
17
+ author={Jieneng Chen and Qihang Yu and Xiaohui Shen and Alan Yuille and Liang-Chieh Chen},
18
+ year={2023},
19
+ archivePrefix={arXiv},
20
+ primaryClass={cs.CV}
21
+ }
22
+
23
+ Based on Apache 2.0 licensed code at https://github.com/ViTamin/ViTamin
24
+
25
+ Modifications and timm support by Jieneng Chen 2023
26
+
27
+ Reference:
28
+ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
29
+ https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer_hybrid.py
30
+ """
31
+
32
+ import math
33
+ from dataclasses import dataclass
34
+ from functools import partial
35
+ import torch.nn.functional as F
36
+ from typing import Optional, Tuple, Union
37
+
38
+ import timm
39
+ import torch
40
+ import torch.nn as nn
41
+ from timm.layers import to_2tuple
42
+ from timm.layers.norm_act import _create_act
43
+ from timm.models._builder import build_model_with_cfg
44
+ from timm.models._manipulate import checkpoint_seq, named_apply
45
+ from timm.models._registry import register_model
46
+ from timm.models.layers import DropPath
47
+ from timm.models.layers import create_conv2d, get_norm_act_layer, get_norm_layer, make_divisible
48
+ from timm.models.vision_transformer import VisionTransformer, checkpoint_filter_fn
49
+ from timm.models.vision_transformer_hybrid import HybridEmbed
50
+ from torch.utils.checkpoint import checkpoint
51
+
52
+ DropPath.__repr__ = lambda self: f'{type(self).__name__}(...)'
53
+
54
+
55
+ @dataclass
56
+ class VitConvCfg:
57
+ expand_ratio: float = 4.0
58
+ expand_output: bool = True # calculate expansion channels from output (vs input chs)
59
+ kernel_size: int = 3
60
+ group_size: int = 1 # 1 == depthwise
61
+ pre_norm_act: bool = False # activation after pre-norm
62
+ stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw'
63
+ pool_type: str = 'avg2'
64
+ downsample_pool_type: str = 'avg2'
65
+ act_layer: str = 'gelu' # stem & stage 1234
66
+ act_layer1: str = 'gelu' # stage 1234
67
+ act_layer2: str = 'gelu' # stage 1234
68
+ norm_layer: str = ''
69
+ norm_layer_cl: str = ''
70
+ norm_eps: Optional[float] = None
71
+ down_shortcut: Optional[bool] = True
72
+ mlp: str = 'mlp'
73
+
74
+ def __post_init__(self):
75
+ # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
76
+ use_mbconv = True
77
+ if not self.norm_layer:
78
+ self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
79
+ if not self.norm_layer_cl and not use_mbconv:
80
+ self.norm_layer_cl = 'layernorm'
81
+ if self.norm_eps is None:
82
+ self.norm_eps = 1e-5 if use_mbconv else 1e-6
83
+ self.downsample_pool_type = self.downsample_pool_type or self.pool_type
84
+
85
+
86
+ @dataclass
87
+ class VitCfg:
88
+ # embed_dim: Tuple[int, ...] = (96, 192, 384, 768)
89
+ embed_dim: Tuple[Union[int, Tuple[int, ...]], ...] = (96, 192, 384, 768)
90
+ depths: Tuple[Union[int, Tuple[int, ...]], ...] = (2, 3, 5, 2)
91
+ stem_width: int = 64
92
+ conv_cfg: VitConvCfg = None
93
+ weight_init: str = 'vit_eff'
94
+ head_type: str = ""
95
+ stem_type: str = "stem"
96
+ ln2d_permute: bool = True
97
+ # memory_format: str=""
98
+
99
+
100
+ def _init_conv(module, name, scheme=''):
101
+ if isinstance(module, nn.Conv2d):
102
+ fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
103
+ fan_out //= module.groups
104
+ nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
105
+ if module.bias is not None:
106
+ nn.init.zeros_(module.bias)
107
+
108
+
109
+ class Stem(nn.Module):
110
+ def __init__(
111
+ self,
112
+ in_chs: int,
113
+ out_chs: int,
114
+ act_layer: str = 'gelu',
115
+ norm_layer: str = 'layernorm2d',
116
+ norm_eps: float = 1e-6,
117
+ bias: bool = True,
118
+ ):
119
+ super().__init__()
120
+ self.grad_checkpointing=False
121
+ norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
122
+ self.out_chs = out_chs
123
+ self.conv1 = create_conv2d(in_chs, out_chs, 3, stride=2, bias=bias)
124
+ self.norm1 = norm_act_layer(out_chs)
125
+ self.conv2 = create_conv2d(out_chs, out_chs, 3, stride=1, bias=bias)
126
+ named_apply(_init_conv, self)
127
+
128
+ def forward(self, x):
129
+ if self.grad_checkpointing:
130
+ x = checkpoint(self.conv1, x)
131
+ x = self.norm1(x)
132
+ x = checkpoint(self.conv2, x)
133
+ else:
134
+ x = self.conv1(x)
135
+ x = self.norm1(x)
136
+ x = self.conv2(x)
137
+
138
+ return x
139
+
140
+
141
+ class Downsample2d(nn.Module):
142
+ def __init__(
143
+ self,
144
+ dim: int,
145
+ dim_out: int,
146
+ pool_type: str = 'avg2',
147
+ bias: bool = True,
148
+ ):
149
+ super().__init__()
150
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)
151
+
152
+ if dim != dim_out:
153
+ self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) # 1x1 conv
154
+ else:
155
+ self.expand = nn.Identity()
156
+
157
+ def forward(self, x):
158
+ x = self.pool(x) # spatial downsample
159
+ x = self.expand(x) # expand chs
160
+ return x
161
+
162
+
163
+ class StridedConv(nn.Module):
164
+ """ downsample 2d as well
165
+ """
166
+ def __init__(
167
+ self,
168
+ kernel_size=3,
169
+ stride=2,
170
+ padding=1,
171
+ in_chans=3,
172
+ embed_dim=768,
173
+ ln2d_permute=True
174
+ ):
175
+ super().__init__()
176
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
177
+ self.permute = ln2d_permute # TODO: disable
178
+ norm_layer = partial(get_norm_layer('layernorm2d'), eps=1e-6)
179
+ self.norm = norm_layer(in_chans) # affine over C
180
+
181
+ def forward(self, x):
182
+ x = self.norm(x)
183
+ x = self.proj(x)
184
+ return x
185
+
186
+
187
+ class MbConvLNBlock(nn.Module):
188
+ """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
189
+ """
190
+ def __init__(
191
+ self,
192
+ in_chs: int,
193
+ out_chs: int,
194
+ stride: int = 1,
195
+ drop_path: float = 0.,
196
+ kernel_size: int = 3,
197
+ norm_layer: str = 'layernorm2d',
198
+ norm_eps: float = 1e-6,
199
+ act_layer: str = 'gelu',
200
+ expand_ratio: float = 4.0,
201
+ ):
202
+ super(MbConvLNBlock, self).__init__()
203
+ self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
204
+ mid_chs = make_divisible(out_chs * expand_ratio)
205
+ prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
206
+
207
+ if stride == 2:
208
+ self.shortcut = Downsample2d(in_chs, out_chs, pool_type='avg', bias=True)
209
+ elif in_chs != out_chs:
210
+ self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
211
+ else:
212
+ self.shortcut = nn.Identity()
213
+
214
+ self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
215
+ self.down = nn.Identity()
216
+ self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
217
+ self.act1 = _create_act(act_layer, inplace=True)
218
+ self.act2 = _create_act(act_layer, inplace=True)
219
+
220
+ self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=stride, dilation=1, groups=mid_chs, bias=True)
221
+ self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
222
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
223
+
224
+ def init_weights(self, scheme=''):
225
+ named_apply(partial(_init_conv, scheme=scheme), self)
226
+
227
+ def forward(self, x):
228
+ shortcut = self.shortcut(x)
229
+
230
+ x = self.pre_norm(x)
231
+ x = self.down(x) # nn.Identity()
232
+
233
+ # 1x1 expansion conv & act
234
+ x = self.conv1_1x1(x)
235
+ x = self.act1(x)
236
+
237
+ # (strided) depthwise 3x3 conv & act
238
+ x = self.conv2_kxk(x)
239
+ x = self.act2(x)
240
+
241
+ # 1x1 linear projection to output width
242
+ x = self.conv3_1x1(x)
243
+ x = self.drop_path(x) + shortcut
244
+
245
+ return x
246
+
247
+
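For orientation, the block above is the standard inverted-bottleneck pattern (1x1 expand, depthwise kxk, 1x1 project) with a pre-norm and a residual shortcut. Below is a stripped-down restatement of just the main branch, with assumed channel counts (128 -> 256, expand_ratio 4, stride 2); the real block additionally applies the pre-norm and adds the Downsample2d / 1x1-conv shortcut:

```python
import torch
import torch.nn as nn

in_chs, out_chs, expand = 128, 256, 4
mid = out_chs * expand                                            # inverted width: 1024
main = nn.Sequential(
    nn.Conv2d(in_chs, mid, 1),                                    # 1x1 expand
    nn.GELU(),
    nn.Conv2d(mid, mid, 3, stride=2, padding=1, groups=mid),      # strided depthwise kxk
    nn.GELU(),
    nn.Conv2d(mid, out_chs, 1),                                   # 1x1 project
)
print(main(torch.randn(1, in_chs, 64, 64)).shape)                 # torch.Size([1, 256, 32, 32])
```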
248
+ class MbConvStages(nn.Module):
249
+ """ MobileConv for stage 1 and stage 2 of ViTamin
250
+ """
251
+ def __init__(
252
+ self,
253
+ cfg: VitCfg,
254
+ img_size: Union[int, Tuple[int, int]] = 224, # placeholder (not used)
255
+ in_chans: int = 3,
256
+ ):
257
+ super().__init__()
258
+ self.grad_checkpointing = False
259
+ self.stem = Stem(
260
+ in_chs=in_chans,
261
+ out_chs=cfg.stem_width,
262
+ )
263
+ stages = []
264
+ self.num_stages = len(cfg.embed_dim)
265
+ for s, dim in enumerate(cfg.embed_dim[:2]): # stage
266
+ blocks = []
267
+ stage_in_chs = cfg.embed_dim[s-1] if s>0 else cfg.stem_width
268
+ for d in range(cfg.depths[s]):
269
+ blocks += [MbConvLNBlock(
270
+ in_chs = stage_in_chs if d==0 else dim,
271
+ out_chs = dim,
272
+ stride = 2 if d == 0 else 1,
273
+ # cfg = cfg.conv_cfg,
274
+ )]
275
+ blocks = nn.Sequential(*blocks)
276
+ stages += [blocks]
277
+
278
+ self.stages = nn.ModuleList(stages)
279
+ self.pool = StridedConv(
280
+ stride=2,
281
+ in_chans=cfg.embed_dim[1],
282
+ embed_dim=cfg.embed_dim[2]
283
+ )
284
+
285
+ def forward(self, x):
286
+ x = self.stem(x)
287
+ if self.grad_checkpointing and not torch.jit.is_scripting():
288
+ for stage in self.stages:
289
+ x = checkpoint_seq(stage, x)
290
+ x = checkpoint(self.pool, x)
291
+ else:
292
+ for stage in self.stages:
293
+ x = stage(x)
294
+ x = self.pool(x)
295
+
296
+ return x
297
+
298
+
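Putting the pieces together: the stem, the first block of each MbConv stage, and the final StridedConv pool are each stride 2, so this convolutional front end reduces resolution by 16x before the transformer stage. For the vitamin_base configuration registered below (embed_dim=(128, 256, 768)), a 256x256 image therefore becomes a 16x16 grid of 768-d tokens. A quick bookkeeping check:

```python
# Each step corresponds to one stride-2 layer in MbConvStages (vitamin_base cfg assumed).
img = 256
after_stem   = img // 2            # Stem.conv1, stride 2          -> 128
after_stage1 = after_stem // 2     # first MbConvLNBlock, stage 1  -> 64
after_stage2 = after_stage1 // 2   # first MbConvLNBlock, stage 2  -> 32
after_pool   = after_stage2 // 2   # StridedConv pool, stride 2    -> 16
print(after_pool, after_pool ** 2) # 16 256 -> 256 tokens of width 768 for the ViT stage
```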
299
+ class GeGluMlp(nn.Module):
300
+ def __init__(
301
+ self,
302
+ in_features,
303
+ hidden_features,
304
+ act_layer = None,
305
+ drop = 0.0,
306
+ ):
307
+ super().__init__()
308
+ norm_layer = partial(get_norm_layer('layernorm'), eps=1e-6)
309
+ self.norm = norm_layer(in_features)
310
+ self.act = nn.GELU(approximate='tanh')
311
+ self.w0 = nn.Linear(in_features, hidden_features)
312
+ self.w1 = nn.Linear(in_features, hidden_features)
313
+ self.w2 = nn.Linear(hidden_features, in_features)
314
+
315
+ def forward(self, x):
316
+ x = self.norm(x)
317
+ x = self.act(self.w0(x)) * self.w1(x)
318
+ x = self.w2(x)
319
+ return x
320
+
321
+
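GeGluMlp is a gated-GELU MLP: after a LayerNorm, one linear branch is passed through GELU and used to gate a second linear branch, and the product is projected back to the input width, i.e. y = W2(GELU(W0 x) * W1 x). A functional restatement with placeholder dimensions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_in, d_hidden = 768, 1536
norm = nn.LayerNorm(d_in)
w0, w1, w2 = nn.Linear(d_in, d_hidden), nn.Linear(d_in, d_hidden), nn.Linear(d_hidden, d_in)

x = torch.randn(2, 16, d_in)
h = norm(x)
y = w2(F.gelu(w0(h), approximate='tanh') * w1(h))   # gate * value, then project back
print(y.shape)                                      # torch.Size([2, 16, 768])
```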
322
+ class HybridEmbed(nn.Module):
323
+ """ CNN Feature Map Embedding
324
+ Extract feature map from CNN, flatten, project to embedding dim.
325
+ """
326
+ def __init__(
327
+ self,
328
+ backbone,
329
+ img_size=256,
330
+ patch_size=1,
331
+ feature_size=None,
332
+ in_chans=3,
333
+ embed_dim=1024,
334
+ bias=True,
335
+ dynamic_img_pad=False,
336
+ ):
337
+ super().__init__()
338
+ assert isinstance(backbone, nn.Module)
339
+ img_size = to_2tuple(img_size)
340
+ patch_size = to_2tuple(patch_size)
341
+ self.img_size = img_size
342
+ self.patch_size = patch_size
343
+ self.backbone = backbone
344
+ with torch.no_grad():
345
+ training = backbone.training
346
+ if training:
347
+ backbone.eval()
348
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
349
+ if isinstance(o, (list, tuple)):
350
+ o = o[-1] # last feature if backbone outputs list/tuple of features
351
+ feature_size = o.shape[-2:]
352
+ feature_dim = o.shape[1]
353
+ backbone.train(training)
354
+
355
+ assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
356
+ self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
357
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
358
+ self.proj = nn.Identity()
359
+
360
+ def forward(self, x):
361
+ x = self.backbone(x)
362
+ if isinstance(x, (list, tuple)):
363
+ x = x[-1] # last feature if backbone outputs list/tuple of features
364
+ x = self.proj(x)
365
+ x = x.flatten(2).transpose(1, 2)
366
+ return x
367
+
368
+
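With patch_size=1, HybridEmbed does no further patching: every spatial position of the backbone's feature map becomes one token. The flatten/transpose at the end is the usual (B, C, H, W) -> (B, H*W, C) conversion, e.g.:

```python
import torch

feat = torch.randn(2, 768, 16, 16)            # backbone output (assumed vitamin_base at 256 px)
tokens = feat.flatten(2).transpose(1, 2)      # (B, C, H, W) -> (B, H*W, C)
print(tokens.shape)                           # torch.Size([2, 256, 768])
```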
369
+ class Upsample2d(nn.Module):
370
+ def __init__(self, dim, dim_out):
371
+ super().__init__()
372
+ self.conv = torch.nn.Conv2d(dim, dim_out, kernel_size=3, stride=1, padding=1)
373
+
374
+ def forward(self, x):
375
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
376
+ x = self.conv(x)
377
+ return x
378
+
379
+
380
+ class InvMbConvLNBlock(nn.Module):
381
+ """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)
382
+ """
383
+ def __init__(
384
+ self,
385
+ in_chs: int,
386
+ out_chs: int,
387
+ stride: int = 1,
388
+ drop_path: float = 0.,
389
+ kernel_size: int = 3,
390
+ norm_layer: str = 'layernorm2d',
391
+ norm_eps: float = 1e-6,
392
+ act_layer: str = 'gelu',
393
+ expand_ratio: float = 4.0,
394
+ ):
395
+ super().__init__()
396
+ self.stride, self.in_chs, self.out_chs = stride, in_chs, out_chs
397
+ mid_chs = make_divisible(out_chs * expand_ratio)
398
+ prenorm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
399
+
400
+ if stride == 2:
401
+ self.shortcut = Upsample2d(in_chs, out_chs)
402
+ elif in_chs != out_chs:
403
+ self.shortcut = nn.Conv2d(in_chs, out_chs, 1, bias=True)
404
+ else:
405
+ self.shortcut = nn.Identity()
406
+
407
+ self.pre_norm = prenorm_act_layer(in_chs, apply_act=False)
408
+
409
+ self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=1, bias=True)
410
+ self.act1 = _create_act(act_layer, inplace=True)
411
+ self.act2 = _create_act(act_layer, inplace=True)
412
+
413
+ self.up = Upsample2d(mid_chs, mid_chs) if stride == 2 else nn.Identity()
414
+ self.conv2_kxk = create_conv2d(mid_chs, mid_chs, kernel_size, stride=1, dilation=1, groups=mid_chs, bias=True)
415
+ self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=True)
416
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
417
+
418
+ def init_weights(self, scheme=''):
419
+ named_apply(partial(_init_conv, scheme=scheme), self)
420
+
421
+ def forward(self, x):
422
+ shortcut = self.shortcut(x)
423
+ x = self.pre_norm(x)
424
+
425
+ # 1x1 expansion conv & act
426
+ x = self.conv1_1x1(x)
427
+ x = self.act1(x)
428
+ x = self.up(x)
429
+
430
+ # (strided) depthwise 3x3 conv & act
431
+ x = self.conv2_kxk(x)
432
+ x = self.act2(x)
433
+
434
+ # 1x1 linear projection to output width
435
+ x = self.conv3_1x1(x)
436
+ x = self.drop_path(x) + shortcut
437
+
438
+ return x
439
+
440
+
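This is the decoder-side mirror of MbConvLNBlock: a "stride" of 2 now means upsample by 2, implemented as nearest-neighbour interpolation followed by a 3x3 conv (Upsample2d) on both the main branch and the shortcut, rather than a transposed convolution. The resizing step in isolation, with assumed channel counts taken from the vitamin_large decoder (320 -> 160):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(320, 160, 3, padding=1)              # Upsample2d = nearest x2, then 3x3 conv
x = torch.randn(1, 320, 64, 64)
y = conv(F.interpolate(x, scale_factor=2, mode='nearest'))
print(y.shape)                                        # torch.Size([1, 160, 128, 128])
```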
441
+ class InvStem(nn.Module):
442
+ def __init__(
443
+ self,
444
+ in_chs: int,
445
+ out_chs: int,
446
+ act_layer: str = 'gelu',
447
+ norm_layer: str = 'layernorm2d',
448
+ norm_eps: float = 1e-6,
449
+ bias: bool = True,
450
+ ):
451
+ super().__init__()
452
+ self.grad_checkpointing=False
453
+ norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
454
+ self.out_chs = out_chs
455
+ self.conv1 = Upsample2d(in_chs, in_chs)
456
+ self.norm1 = norm_act_layer(in_chs)
457
+ self.conv2 = create_conv2d(in_chs, out_chs, 3, stride=1, bias=bias)
458
+ named_apply(_init_conv, self)
459
+
460
+ def forward(self, x):
461
+ if self.grad_checkpointing:
462
+ x = checkpoint(self.conv1, x)
463
+ x = self.norm1(x)
464
+ x = checkpoint(self.conv2, x)
465
+ else:
466
+ x = self.conv1(x)
467
+ x = self.norm1(x)
468
+ x = self.conv2(x)
469
+
470
+ return x
471
+
472
+
473
+ class ViTaminDecoder(nn.Module):
474
+ def __init__(
475
+ self,
476
+ model,
477
+ num_query=0,
478
+ img_size=256,
479
+ drop_path=0.,
480
+ depths=(4, 2),
481
+ grad_ckpt=False,
482
+ ):
483
+ super().__init__()
484
+
485
+ self.num_query = num_query
486
+ vit = timm.create_model(
487
+ model,
488
+ fc_norm=False,
489
+ patch_size=1,
490
+ drop_rate=0.0,
491
+ num_classes=0,
492
+ global_pool='',
493
+ pos_embed='none',
494
+ mlp_layer=GeGluMlp,
495
+ class_token=False,
496
+ reg_tokens=num_query,
497
+ img_size=img_size,
498
+ drop_path_rate=drop_path,
499
+ )
500
+ self.blocks = vit.blocks
501
+ self.norm_pre = vit.norm_pre
502
+ self.norm = vit.norm
503
+
504
+ embed_dims = {
505
+ 'vitamin_base': (768, 256, 128),
506
+ 'vitamin_large': (1024, 320, 160)
507
+ }[model]
508
+ self.up_conv1 = Upsample2d(embed_dims[0], embed_dims[1])
509
+ self.up_conv2 = nn.Sequential(*[
510
+ InvMbConvLNBlock(
511
+ in_chs=embed_dims[1],
512
+ out_chs=embed_dims[1],
513
+ stride=2 if d == 0 else 1)
514
+ for d in range(depths[0])]
515
+ )
516
+ self.up_conv3 = nn.Sequential(*[
517
+ InvMbConvLNBlock(
518
+ in_chs=embed_dims[1] if d == 0 else embed_dims[2],
519
+ out_chs=embed_dims[2],
520
+ stride=2 if d == 0 else 1)
521
+ for d in range(depths[1])]
522
+ )
523
+ self.up_conv4 = InvStem(in_chs=embed_dims[2], out_chs=3)
524
+
525
+ self.grad_ckpt = grad_ckpt
526
+
527
+ def get_last_param(self):
528
+ return self.up_conv4.conv2.weight
529
+
530
+ def forward(self, x):
531
+ B, L, C = x.shape
532
+ H = W = int((L-self.num_query) ** 0.5)
533
+ x = self.norm_pre(x)
534
+ if self.grad_ckpt:
535
+ x = checkpoint_seq(self.blocks, x)
536
+ x = x[:, self.num_query:, :]
537
+ x = self.norm(x)
538
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2)
539
+ x = checkpoint(self.up_conv1, x)
540
+ x = checkpoint_seq(self.up_conv2, x)
541
+ x = checkpoint_seq(self.up_conv3, x)
542
+ else:
543
+ x = self.blocks(x)
544
+ x = x[:, self.num_query:, :]
545
+ x = self.norm(x)
546
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2)
547
+ x = self.up_conv1(x)
548
+ x = self.up_conv2(x)
549
+ x = self.up_conv3(x)
550
+ x = self.up_conv4(x)
551
+ return x
552
+
553
+
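End to end, the decoder maps a token sequence back to pixels: each of up_conv1..up_conv4 doubles the resolution, so a 16x16 token grid becomes a 256x256 image. A usage sketch, assuming this module is importable as unitok.vitamin and that the installed timm version accepts the keyword arguments used above (shapes correspond to vitamin_base at 256 px):

```python
import torch
from unitok.vitamin import ViTaminDecoder

dec = ViTaminDecoder('vitamin_base', num_query=0, img_size=256)
tokens = torch.randn(1, 256, 768)        # 16x16 grid of 768-d tokens
with torch.no_grad():
    img = dec(tokens)
print(img.shape)                         # expected: torch.Size([1, 3, 256, 256])
```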
554
+ def _create_vision_transformer(variant, pretrained=False, grad_ckpt=False, **kwargs) -> VisionTransformer:
555
+ if kwargs.get('features_only', None):
556
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
557
+
558
+ if 'flexi' in variant:
559
+ # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
560
+ # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
561
+ _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
562
+ else:
563
+ _filter_fn = checkpoint_filter_fn
564
+
565
+ return build_model_with_cfg(
566
+ VisionTransformer,
567
+ variant,
568
+ pretrained,
569
+ pretrained_filter_fn=_filter_fn,
570
+ **kwargs,
571
+ )
572
+
573
+
574
+ def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
575
+ embed_layer = partial(HybridEmbed, backbone=backbone)
576
+ kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
577
+ return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
578
+
579
+
580
+ @register_model
581
+ def vitamin_small(pretrained=False, **kwargs) -> VisionTransformer:
582
+ stage_1_2 = MbConvStages(cfg=VitCfg(
583
+ embed_dim=(64, 128, 384),
584
+ depths=(2, 4, 1),
585
+ stem_width=64,
586
+ conv_cfg = VitConvCfg(
587
+ norm_layer='layernorm2d',
588
+ norm_eps=1e-6,
589
+ ),
590
+ head_type='1d',
591
+ ),
592
+ )
593
+ stage3_args = dict(embed_dim=384, depth=14, num_heads=6, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
594
+ model = _create_vision_transformer_hybrid('vitamin_small', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
595
+ return model
596
+
597
+
598
+ @register_model
599
+ def vitamin_base(pretrained=False, **kwargs) -> VisionTransformer:
600
+ stage_1_2 = MbConvStages(cfg=VitCfg(
601
+ embed_dim=(128, 256, 768),
602
+ depths=(2, 4, 1),
603
+ stem_width=128,
604
+ conv_cfg = VitConvCfg(
605
+ norm_layer='layernorm2d',
606
+ norm_eps=1e-6,
607
+ ),
608
+ head_type='1d',
609
+ ),
610
+ )
611
+ stage3_args = dict(embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
612
+ model = _create_vision_transformer_hybrid('vitamin_base', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
613
+ return model
614
+
615
+
616
+ @register_model
617
+ def vitamin_base_256(pretrained=False, **kwargs) -> VisionTransformer:
618
+ stage_1_2 = MbConvStages(cfg=VitCfg(
619
+ embed_dim=(128, 256, 768),
620
+ depths=(2, 4, 1),
621
+ stem_width=128,
622
+ conv_cfg = VitConvCfg(
623
+ norm_layer='layernorm2d',
624
+ norm_eps=1e-6,
625
+ ),
626
+ head_type='1d',
627
+ ),
628
+ )
629
+ stage3_args = dict(img_size=256, embed_dim=768, depth=14, num_heads=12, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
630
+ model = _create_vision_transformer_hybrid('vitamin_base_256', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
631
+ return model
632
+
633
+
634
+ @register_model
635
+ def vitamin_large(pretrained=False, **kwargs) -> VisionTransformer:
636
+ stage_1_2 = MbConvStages(cfg=VitCfg(
637
+ embed_dim=(160, 320, 1024),
638
+ depths=(2, 4, 1),
639
+ stem_width=160,
640
+ conv_cfg = VitConvCfg(
641
+ norm_layer='layernorm2d',
642
+ norm_eps=1e-6,
643
+ ),
644
+ head_type='1d',
645
+ ),
646
+ )
647
+ stage3_args = dict(embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
648
+ model = _create_vision_transformer_hybrid(
649
+ 'vitamin_large', backbone=stage_1_2, pretrained=pretrained, **dict(stage3_args, **kwargs))
650
+ return model
651
+
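Once this module is imported, the @register_model functions above make the variants available through the usual timm factory. A sketch of how an encoder might be created; the keyword arguments mirror the ones used elsewhere in this repo to build the UniTok encoder, and assume a timm version that supports pos_embed='none' and mlp_layer overrides:

```python
import timm
import torch
from unitok.vitamin import GeGluMlp      # importing the module also registers vitamin_*

enc = timm.create_model(
    'vitamin_large',
    patch_size=1,
    num_classes=0,
    global_pool='',
    pos_embed='none',
    class_token=False,
    mlp_layer=GeGluMlp,
    img_size=256,
)
out = enc(torch.randn(1, 3, 256, 256))
print(out.shape)                         # expected: torch.Size([1, 256, 1024])
```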
652
+ # @register_model
653
+ # def vitamin_large_256(pretrained=False, **kwargs) -> VisionTransformer:
654
+ # backbone = MbConvStages(cfg=VitCfg(
655
+ # embed_dim=(160, 320, 1024),
656
+ # depths=(2, 4, 1),
657
+ # stem_width=160,
658
+ # conv_cfg = VitConvCfg(
659
+ # norm_layer='layernorm2d',
660
+ # norm_eps=1e-6,
661
+ # ),
662
+ # head_type='1d',
663
+ # ),
664
+ # )
665
+ # model_args = dict(img_size=256, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
666
+ # model = _create_vision_transformer_hybrid(
667
+ # 'vitamin_large_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
668
+ # return model
669
+
670
+ # @register_model
671
+ # def vitamin_large_336(pretrained=False, **kwargs) -> VisionTransformer:
672
+ # backbone = MbConvStages(cfg=VitCfg(
673
+ # embed_dim=(160, 320, 1024),
674
+ # depths=(2, 4, 1),
675
+ # stem_width=160,
676
+ # conv_cfg = VitConvCfg(
677
+ # norm_layer='layernorm2d',
678
+ # norm_eps=1e-6,
679
+ # ),
680
+ # head_type='1d',
681
+ # ),
682
+ # )
683
+ # model_args = dict(img_size=336, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
684
+ # model = _create_vision_transformer_hybrid(
685
+ # 'vitamin_large_336', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
686
+ # return model
687
+
688
+ # @register_model
689
+ # def vitamin_large_384(pretrained=False, **kwargs) -> VisionTransformer:
690
+ # backbone = MbConvStages(cfg=VitCfg(
691
+ # embed_dim=(160, 320, 1024),
692
+ # depths=(2, 4, 1),
693
+ # stem_width=160,
694
+ # conv_cfg = VitConvCfg(
695
+ # norm_layer='layernorm2d',
696
+ # norm_eps=1e-6,
697
+ # ),
698
+ # head_type='1d',
699
+ # ),
700
+ # )
701
+ # model_args = dict(img_size=384, embed_dim=1024, depth=31, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
702
+ # model = _create_vision_transformer_hybrid(
703
+ # 'vitamin_large_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
704
+ # return model
705
+
706
+ # @register_model
707
+ # def vitamin_xlarge_256(pretrained=False, **kwargs) -> VisionTransformer:
708
+ # backbone = MbConvStages(cfg=VitCfg(
709
+ # embed_dim=(192, 384, 1152),
710
+ # depths=(2, 4, 1),
711
+ # stem_width=192,
712
+ # conv_cfg = VitConvCfg(
713
+ # norm_layer='layernorm2d',
714
+ # norm_eps=1e-6,
715
+ # ),
716
+ # head_type='1d',
717
+ # ),
718
+ # )
719
+ # model_args = dict(img_size=256, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
720
+ # model = _create_vision_transformer_hybrid(
721
+ # 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
722
+ # return model
723
+
724
+ # @register_model
725
+ # def vitamin_xlarge_336(pretrained=False, **kwargs) -> VisionTransformer:
726
+ # backbone = MbConvStages(cfg=VitCfg(
727
+ # embed_dim=(192, 384, 1152),
728
+ # depths=(2, 4, 1),
729
+ # stem_width=192,
730
+ # conv_cfg = VitConvCfg(
731
+ # norm_layer='layernorm2d',
732
+ # norm_eps=1e-6,
733
+ # ),
734
+ # head_type='1d',
735
+ # ),
736
+ # )
737
+ # model_args = dict(img_size=336, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
738
+ # model = _create_vision_transformer_hybrid(
739
+ # 'vitamin_xlarge_256', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
740
+ # return model
741
+
742
+ # @register_model
743
+ # def vitamin_xlarge_384(pretrained=False, **kwargs) -> VisionTransformer:
744
+ # backbone = MbConvStages(cfg=VitCfg(
745
+ # embed_dim=(192, 384, 1152),
746
+ # depths=(2, 4, 1),
747
+ # stem_width=192,
748
+ # conv_cfg = VitConvCfg(
749
+ # norm_layer='layernorm2d',
750
+ # norm_eps=1e-6,
751
+ # ),
752
+ # head_type='1d',
753
+ # ),
754
+ # )
755
+ # model_args = dict(img_size=384, embed_dim=1152, depth=32, num_heads=16, mlp_layer=GeGluMlp, mlp_ratio=2., class_token=False, global_pool='avg')
756
+ # model = _create_vision_transformer_hybrid(
757
+ # 'vitamin_xlarge_384', backbone=backbone, pretrained=pretrained, **dict(model_args, **kwargs))
758
+ # return model
759
+
760
+
761
+ def count_params(model: nn.Module):
762
+ return sum([m.numel() for m in model.parameters()])
763
+
764
+
765
+ def count_stage_params(model: nn.Module, prefix='none'):
766
+ collections = []
767
+ for name, m in model.named_parameters():
768
+ print(name)
769
+ if name.startswith(prefix):
770
+ collections.append(m.numel())
771
+ return sum(collections)
772
+
773
+
774
+ if __name__ == "__main__":
775
+ # ViTaminDecoder('vitamin_base', img_size=256, patch_size=16)
776
+ # model = timm.create_model(
777
+ # 'vitamin_base',
778
+ # fc_norm=True,
779
+ # drop_rate=0.0,
780
+ # num_classes=0,
781
+ # global_pool='',
782
+ # mlp_layer=GeGluMlp,
783
+ # class_token=False,
784
+ # reg_tokens=32,
785
+ # img_size=256,
786
+ # patch_size=1,
787
+ # drop_path_rate=0.1,
788
+ # )
789
+ # print(model.has_class_token)
790
+ # print(model.num_prefix_tokens)
791
+ # print(model.pos_embed.shape)
792
+ Stem(64, 64)
unitok/vqvae.py ADDED
@@ -0,0 +1,175 @@
1
+ import timm
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from contextlib import nullcontext
6
+ from torch.nn.functional import scaled_dot_product_attention
7
+
8
+ from unitok.quant import VectorQuantizerM
9
+ from unitok.vitamin import ViTaminDecoder, GeGluMlp
10
+
11
+
12
+ class PlainAttention(nn.Module):
13
+ def __init__(self, in_dim, out_dim, num_heads):
14
+ super().__init__()
15
+ if in_dim > out_dim:
16
+ # assert in_dim // num_heads == out_dim
17
+ self.head_dim = in_dim // num_heads
18
+ self.qkv = nn.Linear(in_dim, in_dim * 3, bias=False)
19
+ self.q_bias = nn.Parameter(torch.zeros(in_dim))
20
+ self.v_bias = nn.Parameter(torch.zeros(in_dim))
21
+ self.register_buffer('zero_k_bias', torch.zeros(in_dim))
22
+ else:
23
+ # assert out_dim // num_heads == in_dim
24
+ self.head_dim = out_dim // num_heads
25
+ self.qkv = nn.Linear(in_dim, out_dim * 3, bias=False)
26
+ self.q_bias = nn.Parameter(torch.zeros(out_dim))
27
+ self.v_bias = nn.Parameter(torch.zeros(out_dim))
28
+ self.register_buffer('zero_k_bias', torch.zeros(out_dim))
29
+
30
+ self.in_dim = in_dim
31
+ self.out_dim = out_dim
32
+ self.num_heads = num_heads
33
+ self.scale = self.head_dim ** -0.5
34
+ self.proj = nn.Linear(out_dim, out_dim)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ B, N, C = x.shape
38
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias)))
39
+ q, k, v = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
40
+
41
+ x = scaled_dot_product_attention(q, k, v)
42
+
43
+ if self.in_dim > self.out_dim:
44
+ x = torch.mean(x, dim=1)
45
+ if self.in_dim // self.num_heads != self.out_dim:
46
+ x = nn.functional.adaptive_avg_pool1d(x, self.out_dim)
47
+ else:
48
+ x = x.transpose(1, 2).reshape(B, N, -1)
49
+ x = self.proj(x)
50
+ return x
51
+
52
+
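PlainAttention doubles as a width-changing layer. When in_dim > out_dim, the attention output is averaged over heads (and adaptively pooled if the head width still differs from out_dim), shrinking each token; when in_dim <= out_dim, the qkv projection itself widens the tokens. A shape sketch with assumed example widths (encoder width 1024, codebook width 64, 8 heads):

```python
import torch
from unitok.vqvae import PlainAttention   # the class defined above

attn = PlainAttention(in_dim=1024, out_dim=64, num_heads=8)   # head_dim = 128, then pooled to 64
x = torch.randn(2, 256, 1024)
print(attn(x).shape)                                          # torch.Size([2, 256, 64])
```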
53
+ class AttnProjection(nn.Module):
54
+ def __init__(self, in_dim, out_dim, num_heads, norm_layer=nn.LayerNorm, mlp_ratio=2):
55
+ super().__init__()
56
+ assert out_dim % in_dim == 0 or in_dim % out_dim == 0
57
+ self.in_dim = in_dim
58
+ self.out_dim = out_dim
59
+ self.norm1 = norm_layer(in_dim)
60
+ self.attn = PlainAttention(in_dim, out_dim, num_heads)
61
+ self.proj = nn.Linear(in_dim, out_dim)
62
+ self.norm3 = norm_layer(in_dim)
63
+
64
+ self.norm2 = norm_layer(out_dim)
65
+ hidden_dim = int(out_dim * mlp_ratio)
66
+ self.mlp = GeGluMlp(
67
+ in_features=out_dim,
68
+ hidden_features=hidden_dim
69
+ )
70
+
71
+ def forward(self, x):
72
+ x = self.proj(self.norm3(x)) + self.attn(self.norm1(x))
73
+ x = x + self.mlp(self.norm2(x))
74
+ return x
75
+
76
+
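AttnProjection wraps PlainAttention with a pre-normed linear skip connection and a GeGlu MLP; it is what the VQVAE below uses (when quant_proj='attn') to move tokens between the encoder width and the codebook width and back. A round-trip shape sketch with the same assumed example widths as above:

```python
import torch
from unitok.vqvae import AttnProjection

to_code   = AttnProjection(in_dim=1024, out_dim=64, num_heads=8)    # encoder width -> codebook width
from_code = AttnProjection(in_dim=64, out_dim=1024, num_heads=8)    # codebook width -> encoder width
x = torch.randn(2, 256, 1024)
z = to_code(x)
print(z.shape, from_code(z).shape)   # torch.Size([2, 256, 64]) torch.Size([2, 256, 1024])
```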
77
+ class VQVAE(nn.Module):
78
+ def __init__(self, args):
79
+ super().__init__()
80
+
81
+ # 1. build encoder
82
+ self.encoder = timm.create_model(
83
+ args.model,
84
+ patch_size=1,
85
+ fc_norm=True,
86
+ drop_rate=0.0,
87
+ num_classes=0,
88
+ global_pool='',
89
+ pos_embed='none',
90
+ class_token=False,
91
+ mlp_layer=GeGluMlp,
92
+ img_size=args.img_size,
93
+ drop_path_rate=args.drop_path,
94
+ )
95
+ self.encoder.set_grad_checkpointing(args.grad_ckpt)
96
+
97
+ # 2. build conv before quant
98
+ if args.quant_proj == 'linear':
99
+ self.quant_proj = nn.Linear(self.encoder.embed_dim, args.vocab_width)
100
+ elif args.quant_proj == 'attn':
101
+ self.quant_proj = AttnProjection(self.encoder.embed_dim, args.vocab_width, args.num_codebooks)
102
+ else:
103
+ raise NotImplementedError
104
+
105
+ # 3. build quant
106
+ self.quantize = VectorQuantizerM(
107
+ vocab_size=args.vocab_size,
108
+ vocab_width=args.vocab_width,
109
+ beta=args.vq_beta,
110
+ use_entropy_loss=args.le > 0,
111
+ entropy_temp=args.e_temp,
112
+ num_codebooks=args.num_codebooks,
113
+ )
114
+
115
+ # 4. build conv after quant
116
+ if args.quant_proj == 'linear':
117
+ self.post_quant_proj = nn.Linear(args.vocab_width, self.encoder.embed_dim)
118
+ elif args.quant_proj == 'attn':
119
+ self.post_quant_proj = AttnProjection(args.vocab_width, self.encoder.embed_dim, args.num_codebooks)
120
+ else:
121
+ raise NotImplementedError
122
+
123
+ # 5. build decoder
124
+ self.decoder = ViTaminDecoder(
125
+ args.model,
126
+ depths=(4, 2),
127
+ img_size=args.img_size,
128
+ drop_path=args.drop_path,
129
+ grad_ckpt=args.grad_ckpt
130
+ )
131
+
132
+ self.maybe_record_function = nullcontext
133
+
134
+ def forward(self, img):
135
+ features = self.encoder(img).float()
136
+ with torch.cuda.amp.autocast(enabled=False):
137
+ features = self.quant_proj(features)
138
+ quant_out = self.quantize(features)
139
+ features, vq_loss, entropy_loss, usages = quant_out
140
+ features = self.post_quant_proj(features)
141
+ rec_img = self.decoder(features).float()
142
+ return rec_img, vq_loss, entropy_loss, usages
143
+
144
+ def img_to_idx(self, img):
145
+ features = self.encoder(img).float()
146
+ features = self.quant_proj(features)
147
+ return self.quantize.f_to_idx(features)
148
+
149
+ def idx_to_img(self, indices):
150
+ features = self.quantize.idx_to_f(indices)
151
+ features = self.post_quant_proj(features)
152
+ img = self.decoder(features).clamp_(-1, 1)
153
+ return img
154
+
155
+ def img_to_reconstructed_img(self, img) -> torch.Tensor:
156
+ features = self.encoder(img).float()
157
+ with torch.cuda.amp.autocast(enabled=False):
158
+ features = self.quant_proj(features)
159
+ quant_out = self.quantize(features)
160
+ features, _, _, _ = quant_out
161
+ features = self.post_quant_proj(features)
162
+ rec_img = self.decoder(features).float().clamp_(-1, 1)
163
+ return rec_img
164
+
165
+
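For reference, a hedged construction and usage sketch. The field names below are exactly the attributes VQVAE.__init__ reads from args; the concrete values are placeholders chosen for illustration, not UniTok's released training configuration:

```python
import torch
from types import SimpleNamespace
from unitok.vqvae import VQVAE

args = SimpleNamespace(
    model='vitamin_large', img_size=256, drop_path=0.0, grad_ckpt=False,
    quant_proj='attn', vocab_width=64, num_codebooks=8,
    vocab_size=32768, vq_beta=0.25, le=0.0, e_temp=0.01,
)
vae = VQVAE(args)
img = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    rec, vq_loss, entropy_loss, usages = vae(img)
    idx = vae.img_to_idx(img)          # tokenize to codebook indices
    rec2 = vae.idx_to_img(idx)         # and decode back to an image
print(rec.shape, rec2.shape)           # expected: (1, 3, 256, 256) for both
```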
166
+ if __name__ == '__main__':
167
+ for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d,
168
+ nn.ConvTranspose2d):
169
+ setattr(clz, 'reset_parameters', lambda self: None)
170
+
171
+ # NOTE: VQVAE takes a single `args` namespace (see __init__ above); the old
+ # VQVAE(channel_num=..., vocab_norm=...) call, the external `models.init_weights` helper,
+ # and the hard-coded Windows output path were leftovers from an earlier project.
+ from types import SimpleNamespace
+ args = SimpleNamespace(
+ model='vitamin_large', img_size=256, drop_path=0.0, grad_ckpt=False,
+ quant_proj='attn', vocab_width=64, num_codebooks=8,
+ vocab_size=32768, vq_beta=0.25, le=0.0, e_temp=0.01,
+ )
+ cnn = VQVAE(args)
+ torch.save(cnn.state_dict(), 'local_output/cnn.pth')