Flying-Lynx committed
Commit c315863 · 1 Parent(s): 717fa66

add model code
llava/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import LlavaQwenSlowFastForCausalLM
llava/constants.py ADDED
@@ -0,0 +1,15 @@
+ # This file is from https://github.com/haotian-liu/LLaVA/
+
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
+ WORKER_HEART_BEAT_INTERVAL = 15
+
+ LOGDIR = "."
+
+ # Model Constants
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+ DEFAULT_IM_START_TOKEN = "<im_start>"
+ DEFAULT_IM_END_TOKEN = "<im_end>"
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
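
For orientation, these constants drive the multimodal tokenization used throughout the repo: DEFAULT_IMAGE_TOKEN is the placeholder written into prompts, IMAGE_TOKEN_INDEX is the sentinel id spliced into the token stream where visual features are later inserted, and IGNORE_INDEX masks positions out of the training loss. A minimal sketch, assuming a Qwen2 tokenizer (the checkpoint name is only an assumption) and using tokenizer_image_token from llava/mm_utils.py added further down:

from transformers import AutoTokenizer
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

# Hypothetical checkpoint name; any Qwen2-style tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")

prompt = DEFAULT_IMAGE_TOKEN + "\nWhat is happening in this video?"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")

# The text is tokenized normally; a single -200 sentinel marks where the
# visual tokens will later be spliced in by the model.
assert (input_ids == IMAGE_TOKEN_INDEX).sum().item() == 1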
llava/conversation.py ADDED
@@ -0,0 +1,487 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ import dataclasses
4
+ from enum import auto, Enum
5
+ from typing import List, Tuple
6
+ import base64
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ import re
10
+
11
+ class SeparatorStyle(Enum):
12
+ """Different separator style."""
13
+ SINGLE = auto()
14
+ TWO = auto()
15
+ MPT = auto()
16
+ PLAIN = auto()
17
+ LLAMA_2 = auto()
18
+ QWEN = auto()
19
+ CHATML = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
30
+ sep: str = "###"
31
+ sep2: str = None
32
+ version: str = "Unknown"
33
+
34
+ skip_next: bool = False
35
+
36
+ def get_prompt(self):
37
+ messages = self.messages
38
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
39
+ messages = self.messages.copy()
40
+ init_role, init_msg = messages[0].copy()
41
+ init_msg = init_msg[0].replace("<image>", "").strip()
42
+ if 'mmtag' in self.version:
43
+ messages[0] = (init_role, init_msg)
44
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
45
+ messages.insert(1, (self.roles[1], "Received."))
46
+ else:
47
+ messages[0] = (init_role, "<image>\n" + init_msg)
48
+
49
+ if self.sep_style == SeparatorStyle.SINGLE:
50
+ ret = self.system + self.sep
51
+ for role, message in messages:
52
+ if message:
53
+ if type(message) is tuple:
54
+ message, _, _ = message
55
+ ret += role + ": " + message + self.sep
56
+ else:
57
+ ret += role + ":"
58
+ elif self.sep_style == SeparatorStyle.TWO:
59
+ seps = [self.sep, self.sep2]
60
+ ret = self.system + seps[0]
61
+ for i, (role, message) in enumerate(messages):
62
+ if message:
63
+ if type(message) is tuple:
64
+ message, _, _ = message
65
+ ret += role + ": " + message + seps[i % 2]
66
+ else:
67
+ ret += role + ":"
68
+
69
+ elif self.sep_style == SeparatorStyle.CHATML:
70
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
71
+ for role, message in messages:
72
+ if message:
73
+ if type(message) is tuple:
74
+ message, images, _ = message
75
+ message = "<image>" * len(images) + message
76
+ ret += role + "\n" + message + self.sep + "\n"
77
+ else:
78
+ ret += role + "\n"
79
+
80
+ elif self.sep_style == SeparatorStyle.MPT:
81
+ ret = self.system + self.sep
82
+ for role, message in messages:
83
+ if message:
84
+ if type(message) is tuple:
85
+ message, _, _ = message
86
+ ret += role + message + self.sep
87
+ else:
88
+ ret += role
89
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
90
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
91
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
92
+ ret = ""
93
+
94
+ for i, (role, message) in enumerate(messages):
95
+ if i == 0:
96
+ assert message, "first message should not be none"
97
+ assert role == self.roles[0], "first message should come from user"
98
+ if message:
99
+ if type(message) is tuple:
100
+ message, _, _ = message
101
+ if i == 0: message = wrap_sys(self.system) + message
102
+ if i % 2 == 0:
103
+ message = wrap_inst(message)
104
+ ret += self.sep + message
105
+ else:
106
+ ret += " " + message + " " + self.sep2
107
+ else:
108
+ ret += ""
109
+ ret = ret.lstrip(self.sep)
110
+ elif self.sep_style == SeparatorStyle.PLAIN:
111
+ seps = [self.sep, self.sep2]
112
+ ret = self.system
113
+ for i, (role, message) in enumerate(messages):
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ ret += message + seps[i % 2]
118
+ else:
119
+ ret += ""
120
+ else:
121
+ raise ValueError(f"Invalid style: {self.sep_style}")
122
+
123
+ return ret
124
+
125
+ def append_message(self, role, message):
126
+ self.messages.append([role, message])
127
+
128
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
129
+ if image_process_mode == "Pad":
130
+ def expand2square(pil_img, background_color=(122, 116, 104)):
131
+ width, height = pil_img.size
132
+ if width == height:
133
+ return pil_img
134
+ elif width > height:
135
+ result = Image.new(pil_img.mode, (width, width), background_color)
136
+ result.paste(pil_img, (0, (width - height) // 2))
137
+ return result
138
+ else:
139
+ result = Image.new(pil_img.mode, (height, height), background_color)
140
+ result.paste(pil_img, ((height - width) // 2, 0))
141
+ return result
142
+ image = expand2square(image)
143
+ elif image_process_mode in ["Default", "Crop"]:
144
+ pass
145
+ elif image_process_mode == "Resize":
146
+ image = image.resize((336, 336))
147
+ else:
148
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
149
+ if max(image.size) > max_len:
150
+ max_hw, min_hw = max(image.size), min(image.size)
151
+ aspect_ratio = max_hw / min_hw
152
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
153
+ longest_edge = int(shortest_edge * aspect_ratio)
154
+ W, H = image.size
155
+ if H > W:
156
+ H, W = longest_edge, shortest_edge
157
+ else:
158
+ H, W = shortest_edge, longest_edge
159
+ image = image.resize((W, H))
160
+ if return_pil:
161
+ return image
162
+ else:
163
+ buffered = BytesIO()
164
+ image.save(buffered, format=image_format)
165
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
166
+ return img_b64_str
167
+
168
+ def is_image_file(self, filename):
169
+ image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
170
+ return any(filename.lower().endswith(ext) for ext in image_extensions)
171
+
172
+ def is_video_file(self, filename):
173
+ video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
174
+ return any(filename.lower().endswith(ext) for ext in video_extensions)
175
+
176
+ def get_images(self, return_pil=False):
177
+ images = []
178
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
179
+ if i % 2 == 0:
180
+ if type(msg) is tuple:
181
+ msg, image, image_process_mode = msg
182
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
183
+ images.append(image)
184
+ return images
185
+
186
+ def to_gradio_chatbot(self):
187
+ ret = []
188
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
189
+ if i % 2 == 0:
190
+ if type(msg) is tuple:
191
+ msg, image, image_process_mode = msg
192
+ if type(image) != list:
193
+ image = [image]
194
+ if len(image) == 1:
195
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
196
+ else:
197
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
198
+
199
+ img_str_list = []
200
+ for img in image:
201
+ if self.is_image_file(img):
202
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
203
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
204
+ img_str_list.append(img_str)
205
+ elif self.is_video_file(img):
206
+ ret.append(((img,), None))
207
+
208
+ msg = msg.strip()
209
+ img_place_holder = ""
210
+ for img_str in img_str_list:
211
+ img_place_holder += f"{img_str}\n\n"
212
+
213
+ if len(img_str_list) > 0:
214
+ msg = f"{img_place_holder}\n\n{msg}"
215
+
216
+ if len(msg) > 0:
217
+ ret.append([msg, None])
218
+ else:
219
+ ret.append([msg, None])
220
+ else:
221
+ ret[-1][-1] = msg
222
+ return ret
223
+
224
+ def copy(self):
225
+ return Conversation(
226
+ system=self.system,
227
+ roles=self.roles,
228
+ messages=[[x, y] for x, y in self.messages],
229
+ offset=self.offset,
230
+ sep_style=self.sep_style,
231
+ sep=self.sep,
232
+ sep2=self.sep2,
233
+ version=self.version)
234
+
235
+ def dict(self):
236
+ if len(self.get_images()) > 0:
237
+ return {
238
+ "system": self.system,
239
+ "roles": self.roles,
240
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
241
+ "offset": self.offset,
242
+ "sep": self.sep,
243
+ "sep2": self.sep2,
244
+ }
245
+ return {
246
+ "system": self.system,
247
+ "roles": self.roles,
248
+ "messages": self.messages,
249
+ "offset": self.offset,
250
+ "sep": self.sep,
251
+ "sep2": self.sep2,
252
+ }
253
+
254
+
255
+ conv_vicuna_v0 = Conversation(
256
+ system="A chat between a curious human and an artificial intelligence assistant. "
257
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
258
+ roles=("Human", "Assistant"),
259
+ messages=(
260
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
261
+ ("Assistant",
262
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
263
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
264
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
265
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
266
+ "renewable and non-renewable energy sources:\n"
267
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
268
+ "energy sources are finite and will eventually run out.\n"
269
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
270
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
271
+ "and other negative effects.\n"
272
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
273
+ "have lower operational costs than non-renewable sources.\n"
274
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
275
+ "locations than non-renewable sources.\n"
276
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
277
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
278
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
279
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
280
+ ),
281
+ offset=2,
282
+ sep_style=SeparatorStyle.SINGLE,
283
+ sep="###",
284
+ )
285
+
286
+ conv_vicuna_v1 = Conversation(
287
+ system="A chat between a curious user and an artificial intelligence assistant. "
288
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
289
+ roles=("USER", "ASSISTANT"),
290
+ version="v1",
291
+ messages=(),
292
+ offset=0,
293
+ sep_style=SeparatorStyle.TWO,
294
+ sep=" ",
295
+ sep2="</s>",
296
+ )
297
+
298
+ conv_llama_2 = Conversation(
299
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
300
+
301
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
302
+ roles=("USER", "ASSISTANT"),
303
+ version="llama_v2",
304
+ messages=(),
305
+ offset=0,
306
+ sep_style=SeparatorStyle.LLAMA_2,
307
+ sep="<s>",
308
+ sep2="</s>",
309
+ )
310
+
311
+ conv_llava_llama_2 = Conversation(
312
+ system="You are a helpful language and vision assistant. "
313
+ "You are able to understand the visual content that the user provides, "
314
+ "and assist the user with a variety of tasks using natural language.",
315
+ roles=("USER", "ASSISTANT"),
316
+ version="llama_v2",
317
+ messages=(),
318
+ offset=0,
319
+ sep_style=SeparatorStyle.LLAMA_2,
320
+ sep="<s>",
321
+ sep2="</s>",
322
+ )
323
+
324
+ conv_mpt = Conversation(
325
+ system="""<|im_start|>system
326
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
327
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
328
+ version="mpt",
329
+ messages=(),
330
+ offset=0,
331
+ sep_style=SeparatorStyle.MPT,
332
+ sep="<|im_end|>",
333
+ )
334
+
335
+ conv_qwen = Conversation(
336
+ system="""<|im_start|>system
337
+ You are a helpful assistant.""",
338
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
339
+ version="qwen",
340
+ messages=[],
341
+ offset=0,
342
+ sep_style=SeparatorStyle.CHATML,
343
+ sep="<|im_end|>",
344
+ )
345
+
346
+ conv_llava_plain = Conversation(
347
+ system="",
348
+ roles=("", ""),
349
+ messages=(
350
+ ),
351
+ offset=0,
352
+ sep_style=SeparatorStyle.PLAIN,
353
+ sep="\n",
354
+ )
355
+
356
+ conv_llava_v0 = Conversation(
357
+ system="A chat between a curious human and an artificial intelligence assistant. "
358
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
359
+ roles=("Human", "Assistant"),
360
+ messages=(
361
+ ),
362
+ offset=0,
363
+ sep_style=SeparatorStyle.SINGLE,
364
+ sep="###",
365
+ )
366
+
367
+ conv_llava_v0_mmtag = Conversation(
368
+ system="A chat between a curious user and an artificial intelligence assistant. "
369
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
370
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
371
+ roles=("Human", "Assistant"),
372
+ messages=(
373
+ ),
374
+ offset=0,
375
+ sep_style=SeparatorStyle.SINGLE,
376
+ sep="###",
377
+ version="v0_mmtag",
378
+ )
379
+
380
+ conv_llava_v1 = Conversation(
381
+ system="A chat between a curious human and an artificial intelligence assistant. "
382
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
383
+ roles=("USER", "ASSISTANT"),
384
+ version="v1",
385
+ messages=(),
386
+ offset=0,
387
+ sep_style=SeparatorStyle.TWO,
388
+ sep=" ",
389
+ sep2="</s>",
390
+ )
391
+
392
+ conv_llava_v1_mmtag = Conversation(
393
+ system="A chat between a curious user and an artificial intelligence assistant. "
394
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
395
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
396
+ roles=("USER", "ASSISTANT"),
397
+ messages=(),
398
+ offset=0,
399
+ sep_style=SeparatorStyle.TWO,
400
+ sep=" ",
401
+ sep2="</s>",
402
+ version="v1_mmtag",
403
+ )
404
+
405
+ conv_mistral_instruct = Conversation(
406
+ system="",
407
+ roles=("USER", "ASSISTANT"),
408
+ version="llama_v2",
409
+ messages=(),
410
+ offset=0,
411
+ sep_style=SeparatorStyle.LLAMA_2,
412
+ sep="",
413
+ sep2="</s>",
414
+ )
415
+
416
+ conv_chatml_direct = Conversation(
417
+ system="""<|im_start|>system
418
+ Answer the questions.""",
419
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
420
+ version="mpt",
421
+ messages=(),
422
+ offset=0,
423
+ sep_style=SeparatorStyle.MPT,
424
+ sep="<|im_end|>",
425
+ )
426
+
427
+ conv_yi34b_chatml_direct = Conversation(
428
+ system="""<|im_start|>system
429
+ Answer the questions.""",
430
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
431
+ version="mpt-yi-34b",
432
+ messages=(),
433
+ offset=0,
434
+ sep_style=SeparatorStyle.MPT,
435
+ sep="<|im_end|>",
436
+ )
437
+
438
+ conv_llama3 = Conversation(
439
+ system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
440
+ roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
441
+ version="llama3",
442
+ messages=(),
443
+ offset=0,
444
+ sep_style=SeparatorStyle.MPT,
445
+ sep="<|eot_id|>",
446
+ )
447
+
448
+ conv_chatml_direct = Conversation(
449
+ system="""<|im_start|>system
450
+ Answer the questions.""",
451
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
452
+ version="mpt",
453
+ messages=[],
454
+ offset=0,
455
+ sep_style=SeparatorStyle.MPT,
456
+ sep="<|im_end|>",
457
+ )
458
+
459
+ default_conversation = conv_vicuna_v1
460
+ conv_templates = {
461
+ "default": conv_vicuna_v0,
462
+ "v0": conv_vicuna_v0,
463
+ "v1": conv_vicuna_v1,
464
+ "vicuna_v1": conv_vicuna_v1,
465
+ "llama_2": conv_llama_2,
466
+ "mistral_instruct": conv_mistral_instruct,
467
+ "chatml_direct": conv_chatml_direct,
468
+ "yi_34b_chatml_direct": conv_yi34b_chatml_direct,
469
+ "mistral_direct": conv_chatml_direct,
470
+
471
+ "plain": conv_llava_plain,
472
+ "v0_plain": conv_llava_plain,
473
+ "llava_v0": conv_llava_v0,
474
+ "v0_mmtag": conv_llava_v0_mmtag,
475
+ "llava_v1": conv_llava_v1,
476
+ "v1_mmtag": conv_llava_v1_mmtag,
477
+ "llava_llama_2": conv_llava_llama_2,
478
+
479
+ "mpt": conv_mpt,
480
+ "llama3": conv_llama3,
481
+ "qwen_1_5": conv_qwen,
482
+ "qwen_2": conv_qwen,
483
+ }
484
+
485
+
486
+ if __name__ == "__main__":
487
+ print(default_conversation.get_prompt())
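
For reference, a minimal sketch of how these templates are consumed at inference time: pick a template from conv_templates, copy() it so the shared module-level object is not mutated, append the user turn and an empty assistant turn, then render the prompt. The question text is made up for illustration.

from llava.conversation import conv_templates

conv = conv_templates["qwen_2"].copy()      # CHATML-style template used with the Qwen2 backbone
conv.append_message(conv.roles[0], "<image>\nDescribe the video.")
conv.append_message(conv.roles[1], None)    # leave the assistant slot open for generation
prompt = conv.get_prompt()

# prompt is now:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <image>
# Describe the video.<|im_end|>
# <|im_start|>assistant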
llava/mm_utils.py ADDED
@@ -0,0 +1,335 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ import base64
6
+ import torch
7
+ import math
8
+ import ast
9
+
10
+ from transformers import StoppingCriteria
11
+ from llava.constants import IMAGE_TOKEN_INDEX
12
+
13
+ from decord import VideoReader
14
+ from decord import cpu
15
+
16
+ import av
17
+ from av.codec.context import CodecContext
18
+ import numpy as np
19
+
20
+ def get_frame_indices(total_frames, original_fps, target_fps, num_frm):
21
+ sample_fps = round(original_fps / target_fps)
22
+ frame_idx = [i for i in range(0, total_frames, sample_fps)]
23
+ if len(frame_idx) < num_frm:
24
+ # If we have fewer frames than num_frm, just return all the frames
25
+ return frame_idx
26
+ scale = 1.0 * len(frame_idx) / num_frm
27
+ uniform_idx = [round((i + 1) * scale - 1) for i in range(num_frm)]
28
+ frame_idx = [frame_idx[i] for i in uniform_idx]
29
+ return frame_idx
30
+
31
+ def read_video_decord(video_path, num_frm=16, target_fps=2):
32
+ vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
33
+ total_frames = len(vr)
34
+ original_fps = vr.get_avg_fps()
35
+
36
+ target_fps = min(target_fps, original_fps) # target fps should not exceed the video fps
37
+ indices = get_frame_indices(total_frames, original_fps, target_fps, num_frm)
38
+ frames = vr.get_batch(indices)
39
+ vr.seek(0)
40
+
41
+ # video info string
42
+ total_time = total_frames/original_fps
43
+ video_info_string = f"Time: {round(total_time, 2)}s; Time interval between frame {round(total_time/len(indices),3)}s; video tokens:"
44
+
45
+ return frames.asnumpy(), video_info_string
46
+
47
+ def read_video_pyav2(video_path, num_frm=16, target_fps=1, threads=4):
48
+ container = av.open(video_path)
49
+ stream = container.streams.video[0]
50
+
51
+ stream.thread_type = 'AUTO'
52
+ stream.codec_context.thread_count = threads
53
+
54
+ original_fps = stream.average_rate
55
+ total_frames = stream.frames
56
+
57
+ if "webm" not in video_path and "mkv" not in video_path:
58
+ try:
59
+ indices = get_frame_indices(total_frames, original_fps, target_fps, num_frm)
60
+ frames = record_video_length_stream(container, indices)
61
+ except Exception:
62
+ container = av.open(video_path)
63
+ frames = record_video_length_packet(container)
64
+ total_frames = len(frames)
65
+ indices = get_frame_indices(total_frames, original_fps, target_fps, num_frm)
66
+ frames = [frames[i] for i in indices]
67
+ else:
68
+ frames = record_video_length_packet(container)
69
+ total_frames = len(frames)
70
+ indices = get_frame_indices(total_frames, original_fps, target_fps, num_frm)
71
+ frames = [frames[i] for i in indices]
72
+
73
+ return np.stack([x.to_ndarray(format="rgb24") for x in frames])
74
+
75
+ # This one is faster
76
+ def record_video_length_stream(container, indices):
77
+ frames = []
78
+ start_index = indices[0]
79
+ end_index = indices[-1]
80
+ for i, frame in enumerate(container.decode(video=0)):
81
+ if i > end_index:
82
+ break
83
+ if i >= start_index and i in indices:
84
+ frames.append(frame)
85
+ return frames
86
+
87
+
88
+ # This one works for all types of video
89
+ def record_video_length_packet(container):
90
+ frames = []
91
+ # https://github.com/PyAV-Org/PyAV/issues/1269
92
+ # https://www.cnblogs.com/beyond-tester/p/17641872.html
93
+ # context = CodecContext.create("libvpx-vp9", "r")
94
+ for packet in container.demux(video=0):
95
+ for frame in packet.decode():
96
+ frames.append(frame)
97
+ return frames
98
+
99
+
100
+ def select_best_resolution(original_size, possible_resolutions):
101
+ """
102
+ Selects the best resolution from a list of possible resolutions based on the original size.
103
+
104
+ Args:
105
+ original_size (tuple): The original size of the image in the format (width, height).
106
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
107
+
108
+ Returns:
109
+ tuple: The best fit resolution in the format (width, height).
110
+ """
111
+ original_width, original_height = original_size
112
+ best_fit = None
113
+ max_effective_resolution = 0
114
+ min_wasted_resolution = float('inf')
115
+
116
+ for width, height in possible_resolutions:
117
+ scale = min(width / original_width, height / original_height)
118
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
119
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
120
+ wasted_resolution = (width * height) - effective_resolution
121
+
122
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
123
+ max_effective_resolution = effective_resolution
124
+ min_wasted_resolution = wasted_resolution
125
+ best_fit = (width, height)
126
+
127
+ return best_fit
128
+
129
+
130
+ def resize_and_pad_image(image, target_resolution):
131
+ """
132
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
133
+
134
+ Args:
135
+ image (PIL.Image.Image): The input image.
136
+ target_resolution (tuple): The target resolution (width, height) of the image.
137
+
138
+ Returns:
139
+ PIL.Image.Image: The resized and padded image.
140
+ """
141
+ original_width, original_height = image.size
142
+ target_width, target_height = target_resolution
143
+
144
+ scale_w = target_width / original_width
145
+ scale_h = target_height / original_height
146
+
147
+ if scale_w < scale_h:
148
+ new_width = target_width
149
+ new_height = min(math.ceil(original_height * scale_w), target_height)
150
+ else:
151
+ new_height = target_height
152
+ new_width = min(math.ceil(original_width * scale_h), target_width)
153
+
154
+ # Resize the image
155
+ resized_image = image.resize((new_width, new_height))
156
+
157
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
158
+ paste_x = (target_width - new_width) // 2
159
+ paste_y = (target_height - new_height) // 2
160
+ new_image.paste(resized_image, (paste_x, paste_y))
161
+
162
+ return new_image
163
+
164
+
165
+ def divide_to_patches(image, patch_size):
166
+ """
167
+ Divides an image into patches of a specified size.
168
+
169
+ Args:
170
+ image (PIL.Image.Image): The input image.
171
+ patch_size (int): The size of each patch.
172
+
173
+ Returns:
174
+ list: A list of PIL.Image.Image objects representing the patches.
175
+ """
176
+ patches = []
177
+ width, height = image.size
178
+ for i in range(0, height, patch_size):
179
+ for j in range(0, width, patch_size):
180
+ box = (j, i, j + patch_size, i + patch_size)
181
+ patch = image.crop(box)
182
+ patches.append(patch)
183
+
184
+ return patches
185
+
186
+
187
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
188
+ """
189
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
190
+
191
+ Args:
192
+ image_size (tuple): The size of the input image in the format (width, height).
193
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
194
+ patch_size (int): The size of each image patch.
195
+
196
+ Returns:
197
+ tuple: The shape of the image patch grid in the format (width, height).
198
+ """
199
+ if type(grid_pinpoints) is list:
200
+ possible_resolutions = grid_pinpoints
201
+ else:
202
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
203
+ width, height = select_best_resolution(image_size, possible_resolutions)
204
+ return width // patch_size, height // patch_size
205
+
206
+
207
+ def process_anyres_image(image, processor, grid_pinpoints):
208
+ """
209
+ Process an image with variable resolutions.
210
+
211
+ Args:
212
+ image (PIL.Image.Image): The input image to be processed.
213
+ processor: The image processor object.
214
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
215
+
216
+ Returns:
217
+ torch.Tensor: A tensor containing the processed image patches.
218
+ """
219
+ if type(grid_pinpoints) is list:
220
+ possible_resolutions = grid_pinpoints
221
+ else:
222
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
223
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
224
+ image_padded = resize_and_pad_image(image, best_resolution)
225
+
226
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
227
+
228
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
229
+
230
+ image_patches = [image_original_resize] + patches
231
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
232
+ for image_patch in image_patches]
233
+ return torch.stack(image_patches, dim=0)
234
+
235
+
236
+ def load_image_from_base64(image):
237
+ return Image.open(BytesIO(base64.b64decode(image)))
238
+
239
+
240
+ def expand2square(pil_img, background_color):
241
+ width, height = pil_img.size
242
+ if width == height:
243
+ return pil_img
244
+ elif width > height:
245
+ result = Image.new(pil_img.mode, (width, width), background_color)
246
+ result.paste(pil_img, (0, (width - height) // 2))
247
+ return result
248
+ else:
249
+ result = Image.new(pil_img.mode, (height, height), background_color)
250
+ result.paste(pil_img, ((height - width) // 2, 0))
251
+ return result
252
+
253
+
254
+ def process_images(images, image_processor, model_cfg):
255
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
256
+ new_images = []
257
+ if image_aspect_ratio == 'pad':
258
+ for image in images:
259
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
260
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
261
+ new_images.append(image)
262
+ elif image_aspect_ratio == "anyres":
263
+ for image in images:
264
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
265
+ new_images.append(image)
266
+ else:
267
+ return image_processor(images, return_tensors='pt')['pixel_values']
268
+ if all(x.shape == new_images[0].shape for x in new_images):
269
+ new_images = torch.stack(new_images, dim=0)
270
+ return new_images
271
+
272
+
273
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
274
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
275
+
276
+ def insert_separator(X, sep):
277
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
278
+
279
+ input_ids = []
280
+ offset = 0
281
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
282
+ offset = 1
283
+ input_ids.append(prompt_chunks[0][0])
284
+
285
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
286
+ input_ids.extend(x[offset:])
287
+
288
+ if return_tensors is not None:
289
+ if return_tensors == 'pt':
290
+ return torch.tensor(input_ids, dtype=torch.long)
291
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
292
+ return input_ids
293
+
294
+
295
+ def get_model_name_from_path(model_path):
296
+ model_path = model_path.strip("/")
297
+ model_paths = model_path.split("/")
298
+ if model_paths[-1].startswith('checkpoint-'):
299
+ return model_paths[-2] + "_" + model_paths[-1]
300
+ else:
301
+ return model_paths[-1]
302
+
303
+ class KeywordsStoppingCriteria(StoppingCriteria):
304
+ def __init__(self, keywords, tokenizer, input_ids):
305
+ self.keywords = keywords
306
+ self.keyword_ids = []
307
+ self.max_keyword_len = 0
308
+ for keyword in keywords:
309
+ cur_keyword_ids = tokenizer(keyword).input_ids
310
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
311
+ cur_keyword_ids = cur_keyword_ids[1:]
312
+ if len(cur_keyword_ids) > self.max_keyword_len:
313
+ self.max_keyword_len = len(cur_keyword_ids)
314
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
315
+ self.tokenizer = tokenizer
316
+ self.start_len = input_ids.shape[1]
317
+
318
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
319
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
320
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
321
+ for keyword_id in self.keyword_ids:
322
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
323
+ if torch.equal(truncated_output_ids, keyword_id):
324
+ return True
325
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
326
+ for keyword in self.keywords:
327
+ if keyword in outputs:
328
+ return True
329
+ return False
330
+
331
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
332
+ outputs = []
333
+ for i in range(output_ids.shape[0]):
334
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
335
+ return all(outputs)
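
Putting the pieces of this file together, a hedged preprocessing sketch for a single video; the file path is a placeholder, and tokenizer, image_processor, and model are assumed to come from load_pretrained_model in llava/model/builder.py below.

from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import read_video_decord, process_images, tokenizer_image_token

frames, video_info = read_video_decord("clip.mp4", num_frm=16, target_fps=2)   # (T, H, W, 3) uint8
pil_frames = [Image.fromarray(frame) for frame in frames]

# `image_processor`, `model`, and `tokenizer` are assumed to come from load_pretrained_model().
video_tensor = process_images(pil_frames, image_processor, model.config)        # (T, 3, h, w)
input_ids = tokenizer_image_token("<image>\nDescribe the video.", tokenizer,
                                  IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)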
llava/model/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .language_model.llava_qwen2 import LlavaQwenForCausalLM, LlavaQwenConfig
+ from .language_model.llava_qwen2_slowfast import LlavaQwenSlowFastForCausalLM, LlavaQwenSlowFastConfig
llava/model/builder.py ADDED
@@ -0,0 +1,73 @@
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
+
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import os
+ import warnings
+ import shutil
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+ import torch
+ from llava.model import *
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
+     kwargs = {"device_map": device_map, **kwargs}
+
+     if device != "cuda":
+         kwargs['device_map'] = {"": device}
+
+     if load_8bit:
+         kwargs['load_in_8bit'] = True
+     elif load_4bit:
+         kwargs['load_in_4bit'] = True
+         kwargs['quantization_config'] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type='nf4'
+         )
+     else:
+         kwargs['torch_dtype'] = torch.float16
+
+     if use_flash_attn:
+         kwargs['attn_implementation'] = 'flash_attention_2'
+
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     model = LlavaQwenSlowFastForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+     mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+     mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+     if mm_use_im_patch_token:
+         tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+     if mm_use_im_start_end:
+         tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+     model.resize_token_embeddings(len(tokenizer))
+
+     vision_tower = model.get_vision_tower()
+     if not vision_tower.is_loaded:
+         vision_tower.load_model(device_map=device_map)
+     if device_map != 'auto':
+         vision_tower.to(device=device_map, dtype=torch.float16)
+     image_processor = vision_tower.image_processor
+
+     if hasattr(model.config, "max_sequence_length"):
+         context_len = model.config.max_sequence_length
+     else:
+         context_len = 2048
+
+     return tokenizer, model, image_processor, context_len
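
A minimal loading sketch built on load_pretrained_model; the checkpoint path is a placeholder, and downstream generation follows the usual LLaVA pattern rather than anything specific to this commit.

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "path/to/slowfast-llava-qwen2"   # placeholder checkpoint directory
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), use_flash_attn=True
)
model.eval()
# Generation then follows the usual LLaVA pattern: tokenize the prompt with
# tokenizer_image_token, preprocess the video frames with image_processor, and
# call model.generate; see the model classes under llava/model/language_model/.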
llava/model/language_model/hybrid_decoder_layer.py ADDED
@@ -0,0 +1,1473 @@
1
+ """PyTorch Qwen2 model."""
2
+
3
+ import math
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from einops import rearrange
10
+
11
+ from transformers.cache_utils import Cache
12
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
13
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
14
+ from transformers.utils import (
15
+ is_flash_attn_2_available,
16
+ is_flash_attn_greater_or_equal_2_10
17
+ )
18
+ from transformers.activations import ACT2FN
19
+
20
+ if is_flash_attn_2_available():
21
+ from flash_attn.bert_padding import index_first_axis
22
+ from flash_attn import flash_attn_varlen_func
23
+
24
+
25
+ class ScaleDotProductCrossAttention(nn.Module):
26
+
27
+ def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0):
28
+ super().__init__()
29
+ self.layer_number = layer_number
30
+ self.softmax_scale = softmax_scale
31
+ self.dropout_p = attention_dropout
32
+
33
+ def forward(self, q, k, v, attn_mask=None):
34
+ """Implements the multihead softmax attention.
35
+ Arguments
36
+ ---------
37
+ q, k, v: The tensors containing the query, key, and value. (B, Head, S, D)
38
+ """
39
+ # (N,...,L,E)
40
+
41
+ if attn_mask is not None:
42
+ attn_mask = attn_mask[:,None,:,:].repeat(1, q.shape[1], 1, 1)
43
+
44
+ # attention mask, True means it will take part in attention B H s_q s_k
45
+ if self.training:
46
+ dropout_p = self.dropout_p
47
+ else:
48
+ dropout_p = 0.0
49
+
50
+ if q.device.type == "cuda" and attn_mask is not None:
51
+ q = q.contiguous()
52
+ k = k.contiguous()
53
+ v = v.contiguous()
54
+
55
+ # debug only, calculate the FLOPs for cross-attn
56
+ ##################
57
+ # attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(128) # hardcode
58
+ # if attn_mask is not None: # no matter the length, we just slice it
59
+ # causal_mask = attn_mask[:, :, :, : k.shape[-2]]
60
+ # attn_weights = attn_weights + causal_mask
61
+
62
+ # # upcast attention to fp32
63
+ # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
64
+ # # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
65
+ # o = torch.matmul(attn_weights, v)
66
+ ###################
67
+
68
+ o = nn.functional.scaled_dot_product_attention(q, k, v,
69
+ attn_mask=attn_mask,
70
+ dropout_p=dropout_p,
71
+ is_causal=False,
72
+ scale=self.softmax_scale)
73
+
74
+ # B Head L D -> L B (Head D)
75
+ o = rearrange(o, 'B Head L D -> B L (Head D)').contiguous()
76
+
77
+ return o
78
+
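
A small shape sanity-check for the wrapper above, using random tensors (sizes are arbitrary). Queries, keys, and values are laid out as (batch, heads, seq, head_dim), and the boolean mask is (batch, q_len, kv_len) with True marking key/value positions that take part in attention:

import torch

xattn = ScaleDotProductCrossAttention(layer_number=0)
B, H, Lq, Lkv, D = 2, 4, 8, 32, 64
q = torch.randn(B, H, Lq, D)
k = torch.randn(B, H, Lkv, D)
v = torch.randn(B, H, Lkv, D)
mask = torch.ones(B, Lq, Lkv, dtype=torch.bool)    # every kv position visible to every query

out = xattn(q, k, v, attn_mask=mask)
assert out.shape == (B, Lq, H * D)                 # heads are merged back into the hidden dimension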
79
+ class FlashAttnCrossAttention(nn.Module):
80
+
81
+ def __init__(self, layer_number, softmax_scale=None, attention_dropout=0.0):
82
+ super().__init__()
83
+ self.layer_number = layer_number
84
+ self.softmax_scale = softmax_scale
85
+ self.dropout_p = attention_dropout
86
+
87
+ def _get_unpad_data(self, attention_mask: torch.Tensor):
88
+ """
89
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
90
+
91
+ Arguments:
92
+ attention_mask (`torch.Tensor`):
93
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
94
+
95
+ Return:
96
+ indices (`torch.Tensor`):
97
+ The indices of non-masked tokens from the flattened input sequence.
98
+ cu_seqlens (`torch.Tensor`):
99
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
100
+ max_seqlen_in_batch (`int`):
101
+ Maximum sequence length in batch.
102
+ """
103
+ seqlens_in_batch = attention_mask[:, 0, :].sum(dim=-1, dtype=torch.int32) # attn mask are the same for the query dimension, pick the first query
104
+ indices = torch.nonzero(attention_mask[:, 0, :].flatten(), as_tuple=False).flatten()
105
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
106
+ cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
107
+ return (
108
+ indices,
109
+ cu_seqlens,
110
+ max_seqlen_in_batch,
111
+ seqlens_in_batch
112
+ )
113
+ def unpad_q(self, q_layer):
114
+ # no need to unpad, just flatten
115
+
116
+ batch_size, q_seq_len, num_key_value_heads, head_dim = q_layer.shape
117
+ cu_seqlens_q = torch.tensor([q_seq_len] * batch_size, dtype=torch.int32, device=q_layer.device)
118
+ cu_seqlens_q = nn.functional.pad(torch.cumsum(cu_seqlens_q, dim=0, dtype=torch.int32), (1, 0))
119
+ q_layer = q_layer.reshape(batch_size * q_seq_len, num_key_value_heads, head_dim)
120
+
121
+ return (
122
+ q_layer,
123
+ cu_seqlens_q,
124
+ q_seq_len)
125
+ def unpad_kv(self, key_layer, value_layer, attn_mask):
126
+
127
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k, split_size = self._get_unpad_data(attn_mask)
128
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
129
+
130
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
131
+ value_layer = index_first_axis(
132
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
133
+ )
134
+
135
+ return (
136
+ key_layer,
137
+ value_layer,
138
+ indices_k,
139
+ cu_seqlens_k,
140
+ max_seqlen_in_batch_k,
141
+ split_size)
142
+
143
+ def forward(self, q, k, v, attn_mask=None):
144
+ """
145
+ Implements the multihead softmax attention with flash attention varlen api.
146
+ Unpad the kv sequence
147
+ Arguments
148
+ ---------
149
+ q, k, v: The tensors containing the query, key, and value. (B, Head, S, D)
150
+ """
151
+ # (N,...,L,E)
152
+ q = q.transpose(1, 2)
153
+ k = k.transpose(1, 2)
154
+ v = v.transpose(1, 2)
155
+
156
+ # NOTE: don't know if it's necessary
157
+ if q.device.type == "cuda" and attn_mask is not None:
158
+ q = q.contiguous()
159
+ k = k.contiguous()
160
+ v = v.contiguous()
161
+
162
+ # batch_size = q.shape[0]
163
+ # first unpad the q and kv, get cu_seq_len and indices
164
+ batch_size, q_seq_len, head_num, head_dim = q.shape
165
+ q, cu_seq_lens_q, max_seqlen_in_batch_q = self.unpad_q(q)
166
+ k, v, indices_kv, cu_seq_lens_kv, max_seqlen_in_batch_kv, split_size = self.unpad_kv(k, v, attn_mask)
167
+
168
+ attn_output = flash_attn_varlen_func(
169
+ q,
170
+ k,
171
+ v,
172
+ cu_seqlens_q=cu_seq_lens_q,
173
+ cu_seqlens_k=cu_seq_lens_kv,
174
+ max_seqlen_q=max_seqlen_in_batch_q,
175
+ max_seqlen_k=max_seqlen_in_batch_kv,
176
+ dropout_p=self.dropout_p if self.training else 0.0,
177
+ softmax_scale=None,
178
+ causal=False,
179
+ # **flash_kwargs
180
+ )
181
+
182
+ return attn_output.reshape(batch_size, q_seq_len, head_num, head_dim).flatten(2, 3).contiguous()
183
+
184
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
185
+ class Qwen2RMSNorm(nn.Module):
186
+ def __init__(self, hidden_size, eps=1e-6):
187
+ """
188
+ Qwen2RMSNorm is equivalent to T5LayerNorm
189
+ """
190
+ super().__init__()
191
+ self.weight = nn.Parameter(torch.ones(hidden_size))
192
+ self.variance_epsilon = eps
193
+
194
+ def forward(self, hidden_states):
195
+ input_dtype = hidden_states.dtype
196
+ hidden_states = hidden_states.to(torch.float32)
197
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
198
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
199
+ return self.weight * hidden_states.to(input_dtype)
200
+
201
+ def extra_repr(self):
202
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
203
+
204
+
205
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
206
+ class Qwen2RotaryEmbedding(nn.Module):
207
+ def __init__(
208
+ self,
209
+ dim=None,
210
+ max_position_embeddings=2048,
211
+ base=10000,
212
+ device=None,
213
+ scaling_factor=1.0,
214
+ rope_type="default",
215
+ config=None,
216
+ ):
217
+ super().__init__()
218
+ # TODO (joao): remove the `if` below, only used for BC
219
+ self.rope_kwargs = {}
220
+ if config is None:
221
+ self.rope_kwargs = {
222
+ "rope_type": rope_type,
223
+ "factor": scaling_factor,
224
+ "dim": dim,
225
+ "base": base,
226
+ "max_position_embeddings": max_position_embeddings,
227
+ }
228
+ self.rope_type = rope_type
229
+ self.max_seq_len_cached = max_position_embeddings
230
+ self.original_max_seq_len = max_position_embeddings
231
+ else:
232
+ # BC: "rope_type" was originally "type"
233
+ if config.rope_scaling is not None:
234
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
235
+ else:
236
+ self.rope_type = "default"
237
+ self.max_seq_len_cached = config.max_position_embeddings
238
+ self.original_max_seq_len = config.max_position_embeddings
239
+
240
+ self.config = config
241
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
242
+
243
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
244
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
245
+ self.original_inv_freq = self.inv_freq
246
+
247
+ def _dynamic_frequency_update(self, position_ids, device):
248
+ """
249
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
250
+ 1 - growing beyond the cached sequence length (allow scaling)
251
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
252
+ """
253
+ seq_len = torch.max(position_ids) + 1
254
+ if seq_len > self.max_seq_len_cached: # growth
255
+ inv_freq, self.attention_scaling = self.rope_init_fn(
256
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
257
+ )
258
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
259
+ self.max_seq_len_cached = seq_len
260
+
261
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
262
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
263
+ self.max_seq_len_cached = self.original_max_seq_len
264
+
265
+ @torch.no_grad()
266
+ def forward(self, x, position_ids):
267
+ if "dynamic" in self.rope_type:
268
+ self._dynamic_frequency_update(position_ids, device=x.device)
269
+
270
+ # Core RoPE block
271
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
272
+ position_ids_expanded = position_ids[:, None, :].float()
273
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
274
+ device_type = x.device.type
275
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
276
+ with torch.autocast(device_type=device_type, enabled=False):
277
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
278
+ emb = torch.cat((freqs, freqs), dim=-1)
279
+ cos = emb.cos()
280
+ sin = emb.sin()
281
+
282
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
283
+ cos = cos * self.attention_scaling
284
+ sin = sin * self.attention_scaling
285
+
286
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
287
+
288
+
289
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
290
+ def rotate_half(x):
291
+ """Rotates half the hidden dims of the input."""
292
+ x1 = x[..., : x.shape[-1] // 2]
293
+ x2 = x[..., x.shape[-1] // 2 :]
294
+ return torch.cat((-x2, x1), dim=-1)
295
+
296
+
297
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
298
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
299
+ """Applies Rotary Position Embedding to the query and key tensors.
300
+
301
+ Args:
302
+ q (`torch.Tensor`): The query tensor.
303
+ k (`torch.Tensor`): The key tensor.
304
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
305
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
306
+ position_ids (`torch.Tensor`, *optional*):
307
+ Deprecated and unused.
308
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
309
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
310
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
311
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
312
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
313
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
314
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
315
+ Returns:
316
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
317
+ """
318
+ cos = cos.unsqueeze(unsqueeze_dim)
319
+ sin = sin.unsqueeze(unsqueeze_dim)
320
+ q_embed = (q * cos) + (rotate_half(q) * sin)
321
+ k_embed = (k * cos) + (rotate_half(k) * sin)
322
+ return q_embed, k_embed
323
+
324
+
325
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
326
+ class Qwen2MLP(nn.Module):
327
+ def __init__(self, config):
328
+ super().__init__()
329
+ self.hidden_size = config.hidden_size
330
+ self.intermediate_size = config.intermediate_size
331
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
332
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
333
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
334
+ self.act_fn = ACT2FN[config.hidden_act]
335
+
336
+ def forward(self, hidden_state):
337
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
338
+
339
+
340
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
341
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
342
+ """
343
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
344
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
345
+ """
346
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
347
+ if n_rep == 1:
348
+ return hidden_states
349
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
350
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
351
+
352
+
353
+ class Qwen2Attention(nn.Module):
354
+ """
355
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
356
+ and "Generating Long Sequences with Sparse Transformers".
357
+ """
358
+
359
+ def __init__(self, config, layer_idx: Optional[int] = None):
360
+ super().__init__()
361
+ self.config = config
362
+ self.layer_idx = layer_idx
363
+ # if layer_idx is None:
364
+ # logger.warning_once(
365
+ # f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
366
+ # "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
367
+ # "when creating this class."
368
+ # )
369
+
370
+ self.hidden_size = config.hidden_size
371
+ self.num_heads = config.num_attention_heads
372
+ self.head_dim = self.hidden_size // self.num_heads
373
+ self.num_key_value_heads = config.num_key_value_heads
374
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
375
+ self.max_position_embeddings = config.max_position_embeddings
376
+ self.rope_theta = config.rope_theta
377
+ self.is_causal = True
378
+ self.attention_dropout = config.attention_dropout
379
+
380
+ if (self.head_dim * self.num_heads) != self.hidden_size:
381
+ raise ValueError(
382
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
383
+ f" and `num_heads`: {self.num_heads})."
384
+ )
385
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
386
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
387
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
388
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
389
+
390
+ self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
391
+
392
+ def forward(
393
+ self,
394
+ hidden_states: torch.Tensor,
395
+ attention_mask: Optional[torch.Tensor] = None,
396
+ position_ids: Optional[torch.LongTensor] = None,
397
+ past_key_value: Optional[Cache] = None,
398
+ output_attentions: bool = False,
399
+ use_cache: bool = False,
400
+ cache_position: Optional[torch.LongTensor] = None,
401
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
402
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
403
+ bsz, q_len, _ = hidden_states.size()
404
+
405
+ query_states = self.q_proj(hidden_states)
406
+ key_states = self.k_proj(hidden_states)
407
+ value_states = self.v_proj(hidden_states)
408
+
409
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
410
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
411
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
412
+
413
+ if position_embeddings is None:
414
+ # logger.warning_once(
415
+ # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
416
+ # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
417
+ # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
418
+ # "removed and `position_embeddings` will be mandatory."
419
+ # )
420
+ cos, sin = self.rotary_emb(value_states, position_ids)
421
+ else:
422
+ cos, sin = position_embeddings
423
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
424
+
425
+ if past_key_value is not None:
426
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
427
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
428
+
429
+ # repeat k/v heads if n_kv_heads < n_heads
430
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
431
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
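+ # With grouped-query attention each key/value head is shared by `num_key_value_groups` query heads;
+ # `repeat_kv` expands K/V so the plain matmul below can treat them as regular multi-head attention.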
432
+
433
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
434
+ if attention_mask is not None: # no matter the length, we just slice it
435
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
436
+ attn_weights = attn_weights + causal_mask
437
+
438
+ # upcast attention to fp32
439
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
440
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
441
+ attn_output = torch.matmul(attn_weights, value_states)
442
+
443
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
444
+ raise ValueError(
445
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
446
+ f" {attn_output.size()}"
447
+ )
448
+
449
+ attn_output = attn_output.transpose(1, 2).contiguous()
450
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
451
+
452
+ attn_output = self.o_proj(attn_output)
453
+
454
+ if not output_attentions:
455
+ attn_weights = None
456
+
457
+ return attn_output, attn_weights, past_key_value
458
+
459
+
460
+ class Qwen2FlashAttention2(Qwen2Attention):
461
+ """
462
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
463
+ as the weights of the module stay untouched. The only required change would be on the forward pass
464
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
465
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
466
+ config.max_window_layers layers.
467
+ """
468
+
469
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
470
+ def __init__(self, *args, **kwargs):
471
+ super().__init__(*args, **kwargs)
472
+
473
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
474
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
475
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
476
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ attention_mask: Optional[torch.Tensor] = None,
482
+ position_ids: Optional[torch.LongTensor] = None,
483
+ past_key_value: Optional[Cache] = None,
484
+ output_attentions: bool = False,
485
+ use_cache: bool = False,
486
+ cache_position: Optional[torch.LongTensor] = None,
487
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
488
+ ):
489
+ bsz, q_len, _ = hidden_states.size()
490
+
491
+ query_states = self.q_proj(hidden_states)
492
+ key_states = self.k_proj(hidden_states)
493
+ value_states = self.v_proj(hidden_states)
494
+
495
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
496
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
497
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
498
+
499
+ if position_embeddings is None:
500
+ # logger.warning_once(
501
+ # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
502
+ # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
503
+ # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
504
+ # "removed and `position_embeddings` will be mandatory."
505
+ # )
506
+ cos, sin = self.rotary_emb(value_states, position_ids)
507
+ else:
508
+ cos, sin = position_embeddings
509
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
510
+
511
+ if past_key_value is not None:
512
+ # Only activate the slicing cache if the config has a `sliding_window` attribute
513
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
514
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
515
+ if (
516
+ getattr(self.config, "sliding_window", None) is not None
517
+ and kv_seq_len > self.config.sliding_window
518
+ and cache_has_contents
519
+ ):
520
+ slicing_tokens = 1 - self.config.sliding_window
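+ # negative slice start: keep only the most recent (sliding_window - 1) cached entries so that,
+ # together with the newly appended key/value, the cache never exceeds the sliding window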
521
+
522
+ past_key = past_key_value[self.layer_idx][0]
523
+ past_value = past_key_value[self.layer_idx][1]
524
+
525
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
526
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
527
+
528
+ if past_key.shape[-2] != self.config.sliding_window - 1:
529
+ raise ValueError(
530
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
531
+ f" {past_key.shape}"
532
+ )
533
+
534
+ if attention_mask is not None:
535
+ attention_mask = attention_mask[:, slicing_tokens:]
536
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
537
+
538
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
539
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
540
+
541
+ # repeat k/v heads if n_kv_heads < n_heads
542
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
543
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
544
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
545
+
546
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
547
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
548
+ # cast them back to float16 just to be sure everything works as expected.
549
+ input_dtype = query_states.dtype
550
+ if input_dtype == torch.float32:
551
+ if torch.is_autocast_enabled():
552
+ target_dtype = torch.get_autocast_gpu_dtype()
553
+ # Handle the case where the model is quantized
554
+ elif hasattr(self.config, "_pre_quantization_dtype"):
555
+ target_dtype = self.config._pre_quantization_dtype
556
+ else:
557
+ target_dtype = self.q_proj.weight.dtype
558
+
559
+ # logger.warning_once(
560
+ # f"The input hidden states seems to be silently casted in float32, this might be related to"
561
+ # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
562
+ # f" {target_dtype}."
563
+ # )
564
+
565
+ query_states = query_states.to(target_dtype)
566
+ key_states = key_states.to(target_dtype)
567
+ value_states = value_states.to(target_dtype)
568
+
569
+ # Reshape to the expected shape for Flash Attention
570
+ query_states = query_states.transpose(1, 2)
571
+ key_states = key_states.transpose(1, 2)
572
+ value_states = value_states.transpose(1, 2)
573
+
574
+ if (
575
+ self.config.use_sliding_window
576
+ and getattr(self.config, "sliding_window", None) is not None
577
+ and self.layer_idx >= self.config.max_window_layers
578
+ ):
579
+ sliding_window = self.config.sliding_window
580
+ else:
581
+ sliding_window = None
582
+
583
+ attn_output = _flash_attention_forward(
584
+ query_states,
585
+ key_states,
586
+ value_states,
587
+ attention_mask,
588
+ q_len,
589
+ position_ids=position_ids,
590
+ dropout=dropout_rate,
591
+ sliding_window=sliding_window,
592
+ is_causal=self.is_causal,
593
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
594
+ )
595
+
596
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
597
+ attn_output = self.o_proj(attn_output)
598
+
599
+ if not output_attentions:
600
+ attn_weights = None
601
+
602
+ return attn_output, attn_weights, past_key_value
603
+
604
+
605
+ class Qwen2HybridFlashAttention2(Qwen2FlashAttention2):
606
+ """
607
+ Hybrid variant of `Qwen2FlashAttention2` used by the slow-fast decoder layers. On top of the
+ standard causal self-attention, it adds a gated cross-attention branch in which, depending on
+ `cross_attn_implementation`, either all tokens or only the text tokens attend to the slow visual tokens.
612
+ """
613
+
614
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
615
+ def __init__(self,
616
+ is_hyper_enabled,
617
+ gating_type,
618
+ cross_attn_implementation,
619
+ *args, **kwargs):
620
+ super().__init__(*args, **kwargs)
621
+
622
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
623
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
624
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
625
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
626
+
627
+ self.is_hyper_enabled = is_hyper_enabled
628
+ if self.is_hyper_enabled:
629
+ self.gating_type = gating_type
630
+ self.cross_attention_implementation = cross_attn_implementation
631
+ self.cross_attn_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias=True)
632
+
633
+ if gating_type.startswith("whole-dynamic"):
634
+ if "tanh" in gating_type:
635
+ self.cross_attn_gate_proj = nn.Sequential(
636
+ nn.Linear(self.hidden_size, 1),
637
+ nn.Tanh()
638
+ )
639
+ else:
640
+ self.cross_attn_gate_proj = nn.Sequential(
641
+ nn.Linear(self.hidden_size, 1),
642
+ )
643
+
644
+ if gating_type.endswith("warmup"):
645
+ self.cross_attn_warm_up_gate = torch.nn.Parameter(torch.zeros(1))
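+ # Gating design: the "whole-dynamic" gate predicts a single scalar per token from its hidden state
+ # (optionally squashed by tanh); the "warmup" variant additionally multiplies by a learnable scalar
+ # initialized at zero, so the cross-attention contribution starts as a no-op and is learned gradually.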
646
+
647
+ if "flashattn" in self.cross_attention_implementation:
648
+ self.cross_attn_core_attention = FlashAttnCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout)
649
+ else:
650
+ self.cross_attn_core_attention = ScaleDotProductCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout)
651
+
652
+
653
+ def all2media_cross_attn(self,
654
+ text_state,
655
+ text_query,
656
+ vision_features,
657
+ text2vision_cross_attn_mask=None,
658
+ all_text_mask=None):
659
+ '''
660
+ text_query: [s b h d]
661
+ text_state: [s b d]
662
+ vision_features: [num_vis, b, d]
663
+ '''
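+ # Every token cross-attends to the (slow) visual features: project the visual features to K/V with
+ # `cross_attn_kv_proj`, repeat them for grouped-query attention, attend with the text queries, zero the
+ # result for pure-text samples, scale by the dynamic gate, and add it back residually.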
664
+
665
+ if vision_features is None or (self.is_hyper_enabled == False):
666
+ return text_state
667
+
668
+ L_c, B_c = text_state.shape[:2]
669
+ D_head = self.head_dim
670
+
671
+ if "whole-dynamic" in self.gating_type:
672
+ gate_value = self.cross_attn_gate_proj(text_state) # n, bs, head_D
673
+ if "warmup" in self.gating_type:
674
+ gate_value = gate_value * self.cross_attn_warm_up_gate
675
+
676
+ vision_features = vision_features.contiguous()
677
+ vision_features = self.cross_attn_kv_proj(vision_features)
678
+ text_query = rearrange(text_query, 'L B H D -> B H L D') # e.g. [25, 2, 32, 128]
679
+
680
+ vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads)
681
+ vision_key = vision_kv[0].contiguous() # [b h s d]
682
+ vision_value = vision_kv[1].contiguous()
683
+
684
+ vision_key = repeat_kv(vision_key, self.num_key_value_groups)
685
+ vision_value = repeat_kv(vision_value, self.num_key_value_groups)
686
+
687
+ # expand the cross-attention mask to every query position
688
+ attention_mask = text2vision_cross_attn_mask[:, None, :].repeat(1, text_state.shape[0], 1)
689
+ vision_context = self.cross_attn_core_attention(text_query, vision_key, vision_value, attn_mask=attention_mask).transpose(0, 1)
690
+
691
+ # mask out the output if a sample is pure text
692
+ vision_context = all_text_mask[None, :, None] * vision_context
693
+
694
+ # Apply dynamic gate
695
+ text_state = text_state + vision_context * gate_value
696
+
697
+ return text_state
698
+
699
+ def onlytext2media_cross_attn(self,
700
+ text_state,
701
+ text_query,
702
+ vision_features,
703
+ token_type,
704
+ text2vision_cross_attn_mask=None,
705
+ all_text_mask=None):
706
+ '''
707
+ text_query: [bs n h d]
708
+ text_state: [bs n d]
709
+ vision_features: [bs, vis_n, d]
710
+ token_type: [bs, n]
711
+ '''
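+ # Only text tokens cross-attend to the slow visual tokens. Based on the masks built below, positions
+ # with token_type <= 2 are treated as text and token_type == 3 marks visual tokens; the text tokens are
+ # gathered, padded into a dense batch, attended against the visual K/V, and the gated result is
+ # scattered back into the original sequence.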
712
+ # if vision_features is None or (self.is_hyper_enabled == False) or (all_text_mask.sum() == 0):
713
+ if vision_features is None or (self.is_hyper_enabled == False):
714
+ return text_state
715
+
716
+ # select all the pure text token
717
+ pure_text_query = []
718
+ text_mask = ((token_type - 2) <= 0).bool()
719
+
720
+ if "masksystem" in self.cross_attention_implementation:
721
+ new_text_masks = []
722
+ for idx, text_query_ in enumerate(text_query):
723
+ # mask out all the tokens before the media
724
+ first_im_token = (token_type[idx] == 3).nonzero()
725
+ if len(first_im_token) == 0:
726
+ start = 0
727
+ else:
728
+ start = first_im_token[0]
729
+ text_mask_ = text_mask[idx].clone()
730
+ text_mask_[:start] = False
731
+ pure_text_query.append(text_query_[text_mask_])
732
+ new_text_masks.append(text_mask_)
733
+ text_mask = torch.stack(new_text_masks, dim=0)
734
+ else:
735
+ for idx, text_query_ in enumerate(text_query):
736
+ pure_text_query.append(text_query_[text_mask[idx]])
737
+
738
+ # 2. pad all the text tokens
739
+ text_query = torch.nn.utils.rnn.pad_sequence(pure_text_query, batch_first=True)
740
+ padding_attn_mask = torch.ones(text_query.shape[:-2], dtype=torch.bool, device=text_state.device)
741
+ for i, tensor in enumerate(pure_text_query):
742
+ padding_attn_mask[i, len(tensor):] = False # Mark padded elements as False
743
+
744
+ B_c, L_c = text_query.shape[:2]
745
+ D_head = self.head_dim
746
+
747
+ # obtain dynamic gate value
748
+ gate_value = self.cross_attn_gate_proj(text_state[text_mask]) # n, D
749
+ if "warmup" in self.gating_type:
750
+ gate_value = gate_value * self.cross_attn_warm_up_gate.tanh()
751
+
752
+ vision_features = vision_features.contiguous()
753
+ vision_features = self.cross_attn_kv_proj(vision_features)
754
+ text_query = text_query.transpose(1, 2)
755
+
756
+ vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads)
757
+ vision_key = vision_kv[0].contiguous() # [b h s d]
758
+ vision_value = vision_kv[1].contiguous()
759
+
760
+ vision_key = repeat_kv(vision_key, self.num_key_value_groups)
761
+ vision_value = repeat_kv(vision_value, self.num_key_value_groups)
762
+
763
+ # expand the cross-attention mask to every query position
764
+ attention_mask = text2vision_cross_attn_mask[:, None, :].repeat(1, text_query.shape[2], 1)
765
+ vision_context = self.cross_attn_core_attention(text_query, vision_key, vision_value, attn_mask=attention_mask)
766
+
767
+ # mask out the output if a sample is pure text
768
+ vision_context = all_text_mask[:, None, None] * vision_context
769
+
770
+ # Apply dynamic gate
771
+ extended_attn_output = torch.zeros_like(text_state, dtype=text_state.dtype, device=text_state.device)
772
+ extended_attn_output[text_mask] = extended_attn_output[text_mask] + vision_context[padding_attn_mask] * gate_value
773
+ text_state = text_state + extended_attn_output
774
+ # NOTE (Min): equivalent to the commented line below; written this way to avoid errors under DeepSpeed ZeRO-3
775
+ # text_state[text_mask] = text_state[text_mask] + vision_context[padding_attn_mask] * gate_value
776
+
777
+ return text_state
778
+
779
+
780
+ def forward(
781
+ self,
782
+ hidden_states: torch.Tensor,
783
+ visual_hidden_states: torch.Tensor,
784
+ token_type: torch.Tensor,
785
+ attention_mask: Optional[torch.Tensor] = None,
786
+ text2visual_attention_mask: Optional[torch.Tensor] = None,
787
+ position_ids: Optional[torch.LongTensor] = None,
788
+ past_key_value: Optional[Cache] = None,
789
+ output_attentions: bool = False,
790
+ use_cache: bool = False,
791
+ cache_position: Optional[torch.LongTensor] = None,
792
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
793
+ ):
794
+ bsz, q_len, _ = hidden_states.size()
795
+
796
+ query_states = self.q_proj(hidden_states)
797
+ key_states = self.k_proj(hidden_states)
798
+ value_states = self.v_proj(hidden_states)
799
+
800
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
801
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
802
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
803
+
804
+ if position_embeddings is None:
805
+ # logger.warning_once(
806
+ # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
807
+ # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
808
+ # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
809
+ # "removed and `position_embeddings` will be mandatory."
810
+ # )
811
+ cos, sin = self.rotary_emb(value_states, position_ids)
812
+ else:
813
+ cos, sin = position_embeddings
814
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
815
+
816
+ if past_key_value is not None:
817
+ # Only activate the slicing cache if the config has a `sliding_window` attribute
818
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
819
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
820
+ if (
821
+ getattr(self.config, "sliding_window", None) is not None
822
+ and kv_seq_len > self.config.sliding_window
823
+ and cache_has_contents
824
+ ):
825
+ slicing_tokens = 1 - self.config.sliding_window
826
+
827
+ past_key = past_key_value[self.layer_idx][0]
828
+ past_value = past_key_value[self.layer_idx][1]
829
+
830
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
831
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
832
+
833
+ if past_key.shape[-2] != self.config.sliding_window - 1:
834
+ raise ValueError(
835
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
836
+ f" {past_key.shape}"
837
+ )
838
+
839
+ if attention_mask is not None:
840
+ attention_mask = attention_mask[:, slicing_tokens:]
841
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
842
+
843
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
844
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
845
+
846
+ # repeat k/v heads if n_kv_heads < n_heads
847
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
848
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
849
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
850
+
851
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
852
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
853
+ # cast them back to float16 just to be sure everything works as expected.
854
+ input_dtype = query_states.dtype
855
+ if input_dtype == torch.float32:
856
+ if torch.is_autocast_enabled():
857
+ target_dtype = torch.get_autocast_gpu_dtype()
858
+ # Handle the case where the model is quantized
859
+ elif hasattr(self.config, "_pre_quantization_dtype"):
860
+ target_dtype = self.config._pre_quantization_dtype
861
+ else:
862
+ target_dtype = self.q_proj.weight.dtype
863
+
864
+ # logger.warning_once(
865
+ # f"The input hidden states seems to be silently casted in float32, this might be related to"
866
+ # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
867
+ # f" {target_dtype}."
868
+ # )
869
+
870
+ query_states = query_states.to(target_dtype)
871
+ key_states = key_states.to(target_dtype)
872
+ value_states = value_states.to(target_dtype)
873
+
874
+ # Reshape to the expected shape for Flash Attention
875
+ query_states = query_states.transpose(1, 2)
876
+ key_states = key_states.transpose(1, 2)
877
+ value_states = value_states.transpose(1, 2)
878
+
879
+ if (
880
+ self.config.use_sliding_window
881
+ and getattr(self.config, "sliding_window", None) is not None
882
+ and self.layer_idx >= self.config.max_window_layers
883
+ ):
884
+ sliding_window = self.config.sliding_window
885
+ else:
886
+ sliding_window = None
887
+
888
+ attn_output = _flash_attention_forward(
889
+ query_states, # bs, n, head, head_dim
890
+ key_states,
891
+ value_states,
892
+ attention_mask,
893
+ q_len,
894
+ position_ids=position_ids,
895
+ dropout=dropout_rate,
896
+ sliding_window=sliding_window,
897
+ is_causal=self.is_causal,
898
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
899
+ )
900
+
901
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
902
+
903
+ # text-to-image cross-attention
904
+ ####
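+ # After the causal self-attention over the fast token stream, the hidden states are enriched with
+ # gated cross-attention into the slow visual tokens passed in as `visual_hidden_states`; the
+ # `cross_attention_implementation` string selects which tokens are allowed to attend to them.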
905
+ all_text_mask = (token_type == 3).sum(dim=-1).bool() # [bs,]; False indicates that the sample contains no image input
906
+ if self.cross_attention_implementation.startswith("vanilla"): # all tokens can attend to the slow tokens
907
+ attn_output = self.all2media_cross_attn(attn_output.permute(1, 0, 2),
908
+ query_states.permute(1, 0, 2, 3),
909
+ visual_hidden_states,
910
+ text2visual_attention_mask,
911
+ all_text_mask)
912
+ attn_output = attn_output.permute(1,0,2)
913
+
914
+ elif self.cross_attention_implementation.startswith("text-only-vanilla"): # only text tokens are allowed to attend the slow tokens
915
+ attn_output = self.onlytext2media_cross_attn(attn_output,
916
+ query_states,
917
+ visual_hidden_states,
918
+ token_type=token_type,
919
+ text2vision_cross_attn_mask=text2visual_attention_mask,
920
+ all_text_mask=all_text_mask
921
+ )
922
+ else:
923
+ raise NotImplementedError(f"cross-attention type {self.cross_attention_implementation} not implemented")
924
+ ####
925
+
926
+ attn_output = self.o_proj(attn_output)
927
+
928
+ if not output_attentions:
929
+ attn_weights = None
930
+
931
+ return attn_output, attn_weights, past_key_value
932
+
933
+
934
+ class Qwen2SdpaAttention(Qwen2Attention):
935
+ """
936
+ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
937
+ `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
938
+ SDPA API.
939
+ """
940
+ # Adapted from Qwen2Attention.forward
941
+ def forward(
942
+ self,
943
+ hidden_states: torch.Tensor,
944
+ attention_mask: Optional[torch.Tensor] = None,
945
+ position_ids: Optional[torch.LongTensor] = None,
946
+ past_key_value: Optional[Cache] = None,
947
+ output_attentions: bool = False,
948
+ use_cache: bool = False,
949
+ cache_position: Optional[torch.LongTensor] = None,
950
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
951
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
952
+ if output_attentions:
953
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
954
+ # logger.warning_once(
955
+ # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
956
+ # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
957
+ # )
958
+ return super().forward(
959
+ hidden_states=hidden_states,
960
+ attention_mask=attention_mask,
961
+ position_ids=position_ids,
962
+ past_key_value=past_key_value,
963
+ output_attentions=output_attentions,
964
+ use_cache=use_cache,
965
+ )
966
+
967
+ bsz, q_len, _ = hidden_states.size()
968
+
969
+ query_states = self.q_proj(hidden_states)
970
+ key_states = self.k_proj(hidden_states)
971
+ value_states = self.v_proj(hidden_states)
972
+
973
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
974
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
975
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
976
+
977
+ if position_embeddings is None:
978
+ # logger.warning_once(
979
+ # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
980
+ # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
981
+ # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
982
+ # "removed and `position_embeddings` will be mandatory."
983
+ # )
984
+ cos, sin = self.rotary_emb(value_states, position_ids)
985
+ else:
986
+ cos, sin = position_embeddings
987
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
988
+
989
+ if past_key_value is not None:
990
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
991
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
992
+
993
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
994
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
995
+
996
+ causal_mask = attention_mask
997
+ if attention_mask is not None: # no matter the length, we just slice it
998
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
999
+
1000
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1001
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1002
+ if query_states.device.type == "cuda" and attention_mask is not None:
1003
+ query_states = query_states.contiguous()
1004
+ key_states = key_states.contiguous()
1005
+ value_states = value_states.contiguous()
1006
+
1007
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1008
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1009
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1010
+ is_causal = True if causal_mask is None and q_len > 1 else False
1011
+
1012
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1013
+ query_states,
1014
+ key_states,
1015
+ value_states,
1016
+ attn_mask=causal_mask,
1017
+ dropout_p=self.attention_dropout if self.training else 0.0,
1018
+ is_causal=is_causal,
1019
+ )
1020
+
1021
+ attn_output = attn_output.transpose(1, 2).contiguous()
1022
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1023
+
1024
+ attn_output = self.o_proj(attn_output)
1025
+
1026
+ return attn_output, None, past_key_value
1027
+
1028
+ # TODO (Min): the SDPA hybrid attention below is not fully implemented yet
1029
+ class Qwen2HybridSdpaAttention(Qwen2SdpaAttention):
1030
+ """
1031
+ SDPA counterpart of `Qwen2HybridFlashAttention2`, intended to add the same gated text-to-vision
+ cross-attention on top of `Qwen2SdpaAttention`. This path is not fully implemented yet; the
+ flash-attention variant is the supported one.
1034
+ """
1035
+ def __init__(self,
1036
+ is_hyper_enabled,
1037
+ gating_type,
1038
+ cross_attn_implementation,
1039
+ *args, **kwargs):
1040
+ super().__init__(*args, **kwargs)
1041
+
1042
+ self.is_hyper_enabled = is_hyper_enabled
1043
+
1044
+ if self.is_hyper_enabled:
1045
+ self.gating_type = gating_type
1046
+ self.cross_attention_implementation = cross_attn_implementation
1047
+ self.cross_attn_kv_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim * 2, bias=True)
1048
+
1049
+ if gating_type.startswith("whole-dynamic"):
1050
+ if "tanh" in gating_type:
1051
+ self.cross_attn_gate_proj = nn.Sequential(
1052
+ nn.Linear(self.hidden_size, 1),
1053
+ nn.Tanh()
1054
+ )
1055
+ else:
1056
+ self.cross_attn_gate_proj = nn.Sequential(
1057
+ nn.Linear(self.hidden_size, 1),
1058
+ )
1059
+
1060
+ if gating_type.endswith("warmup"):
1061
+ self.cross_attn_warm_up_gate = torch.nn.Parameter(torch.zeros(1))
1062
+
1063
+ if "flashattn" in self.cross_attention_implementation:
1064
+ self.cross_attn_core_attention = FlashAttnCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout)
1065
+ else:
1066
+ self.cross_attn_core_attention = ScaleDotProductCrossAttention(layer_number=-1, attention_dropout=self.attention_dropout)
1067
+
1068
+ def text2media_cross_attn(self,
1069
+ text_state,
1070
+ text_query,
1071
+ vision_features,
1072
+ text2vision_cross_attn_mask=None,
1073
+ all_text_mask=None):
1074
+ '''
1075
+ text_query: [s b h d]
1076
+ text_state: [s b d]
1077
+ vision_features: [num_vis, b, d]
1078
+ '''
1079
+ if vision_features is None or (self.is_hyper_enabled == False):
1080
+ return text_state
1081
+
1082
+ # obtain the dynamic gate value from the text hidden states
+ L_c, B_c = text_state.shape[:2]
+
+ # NOTE: this SDPA hybrid path is flagged as not fully implemented; the lines below only reference
+ # attributes that are actually defined in `__init__` and mirror the flash-attention variant
+ # (`all2media_cross_attn`) of the same computation.
+ gate_value = self.cross_attn_gate_proj(text_state) # [L, B, 1]
+ if "warmup" in self.gating_type:
+ gate_value = gate_value * self.cross_attn_warm_up_gate
+
+ vision_features = vision_features.contiguous()
+ vision_features = self.cross_attn_kv_proj(vision_features)
+ query_layer = rearrange(text_query, 'L B H D -> B H L D') # e.g. [25, 2, 32, 128]
+
+ vision_kv = rearrange(vision_features, 'BL Lv (H KV D) -> KV BL H Lv D', KV=2, H=self.num_key_value_heads)
+ vision_key = vision_kv[0].contiguous() # [b h s d]
+ vision_value = vision_kv[1].contiguous()
+
+ vision_key = repeat_kv(vision_key, self.num_key_value_groups)
+ vision_value = repeat_kv(vision_value, self.num_key_value_groups)
+
+ # expand the cross-attention mask (if provided) to every query position
+ attn_mask = None
+ if text2vision_cross_attn_mask is not None:
+ attn_mask = text2vision_cross_attn_mask[:, None, :].repeat(1, text_state.shape[0], 1)
+ vision_context = self.cross_attn_core_attention(query_layer, vision_key, vision_value, attn_mask=attn_mask).transpose(0, 1)
+
+ # mask out the output if a sample is pure text
+ if all_text_mask is not None:
+ vision_context = all_text_mask[None, :, None] * vision_context
+
+ # Apply dynamic gate: interpolate between the text states and the attended visual context
+ text_state = text_state * (1 - gate_value) + vision_context * gate_value
1109
+
1110
+ return text_state
1111
+ # Adapted from Qwen2Attention.forward
1112
+ def forward(
1113
+ self,
1114
+ hidden_states: torch.Tensor,
1115
+ visual_hidden_states: torch.Tensor,
1116
+ token_type: torch.Tensor,
1117
+ attention_mask: Optional[torch.Tensor] = None,
1118
+ text2visual_attention_mask: Optional[torch.Tensor] = None,
1119
+ position_ids: Optional[torch.LongTensor] = None,
1120
+ past_key_value: Optional[Cache] = None,
1121
+ output_attentions: bool = False,
1122
+ use_cache: bool = False,
1123
+ cache_position: Optional[torch.LongTensor] = None,
1124
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1125
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1126
+ if output_attentions:
1127
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
1128
+ # logger.warning_once(
1129
+ # "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
1130
+ # 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
1131
+ # )
1132
+ return super().forward(
1133
+ hidden_states=hidden_states,
1134
+ attention_mask=attention_mask,
1135
+ position_ids=position_ids,
1136
+ past_key_value=past_key_value,
1137
+ output_attentions=output_attentions,
1138
+ use_cache=use_cache,
1139
+ )
1140
+
1141
+ bsz, q_len, _ = hidden_states.size()
1142
+
1143
+ query_states = self.q_proj(hidden_states)
1144
+ key_states = self.k_proj(hidden_states)
1145
+ value_states = self.v_proj(hidden_states)
1146
+
1147
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
1148
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1149
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1150
+
1151
+ if position_embeddings is None:
1152
+ # logger.warning_once(
1153
+ # "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
1154
+ # "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
1155
+ # "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
1156
+ # "removed and `position_embeddings` will be mandatory."
1157
+ # )
1158
+ cos, sin = self.rotary_emb(value_states, position_ids)
1159
+ else:
1160
+ cos, sin = position_embeddings
1161
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
1162
+
1163
+ if past_key_value is not None:
1164
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
1165
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
1166
+
1167
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
1168
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
1169
+
1170
+ causal_mask = attention_mask
1171
+ if attention_mask is not None: # no matter the length, we just slice it
1172
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
1173
+
1174
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
1175
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
1176
+ if query_states.device.type == "cuda" and attention_mask is not None:
1177
+ query_states = query_states.contiguous()
1178
+ key_states = key_states.contiguous()
1179
+ value_states = value_states.contiguous()
1180
+
1181
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
1182
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
1183
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
1184
+ is_causal = True if causal_mask is None and q_len > 1 else False
1185
+
1186
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
1187
+ query_states,
1188
+ key_states,
1189
+ value_states,
1190
+ attn_mask=causal_mask,
1191
+ dropout_p=self.attention_dropout if self.training else 0.0,
1192
+ is_causal=is_causal,
1193
+ )
1194
+
1195
+ attn_output = attn_output.transpose(1, 2).contiguous()
1196
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
1197
+
1198
+ # text-to-image cross-attention
1199
+ ####
1200
+ all_text_mask = (token_type == 3).sum(dim=-1).bool() # [bs,]; False indicates that the sample contains no image input
1201
+ if self.cross_attention_implementation.startswith("vanilla"):
1202
+ attn_output = self.text2media_cross_attn(attn_output.permute(1, 0, 2),
1203
+ query_states.permute(1, 0, 2, 3),
1204
+ visual_hidden_states,
1205
+ text2visual_attention_mask,
1206
+ all_text_mask)
1207
+ attn_output = attn_output.permute(1,0,2)
1208
+
1209
+ elif self.cross_attention_implementation.startswith("text-only-vanilla"):
1210
+ attn_output = self.onlytext2media_cross_attn(attn_output,
1211
+ query_states,
1212
+ visual_hidden_states,
1213
+ token_type=token_type,
1214
+ text2vision_cross_attn_mask=text2visual_attention_mask,
1215
+ all_text_mask=all_text_mask
1216
+ )
1217
+ else:
1218
+ raise NotImplementedError(f"cross-attention type {self.cross_attention_implementation} not implemented")
1219
+ ####
1220
+
1221
+ attn_output = self.o_proj(attn_output)
1222
+
1223
+ return attn_output, None, past_key_value
1224
+
1225
+
1226
+ QWEN2_ATTENTION_CLASSES = {
1227
+ "eager": Qwen2Attention,
1228
+ "flash_attention_2": Qwen2FlashAttention2,
1229
+ "sdpa": Qwen2SdpaAttention,
1230
+ }
1231
+
1232
+ QWEN2_HYBRID_ATTENTION_CLASSES = {
+ "flash_attention_2": Qwen2HybridFlashAttention2,
+ "sdpa": Qwen2HybridSdpaAttention, # not fully implemented yet; only flash attention is currently supported
1235
+ }
1236
+
1237
+
1238
+ class Qwen2DecoderLayer(nn.Module):
1239
+ def __init__(self, config, layer_idx: int):
1240
+ super().__init__()
1241
+ self.hidden_size = config.hidden_size
1242
+
1243
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
1244
+ # logger.warning_once(
1245
+ # f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
1246
+ # "unexpected results may be encountered."
1247
+ # )
1248
+ pass
1249
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
1250
+
1251
+ self.mlp = Qwen2MLP(config)
1252
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1253
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1254
+
1255
+ def forward(
1256
+ self,
1257
+ hidden_states: torch.Tensor,
1258
+ attention_mask: Optional[torch.Tensor] = None,
1259
+ position_ids: Optional[torch.LongTensor] = None,
1260
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1261
+ output_attentions: Optional[bool] = False,
1262
+ use_cache: Optional[bool] = False,
1263
+ cache_position: Optional[torch.LongTensor] = None,
1264
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1265
+ **kwargs,
1266
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1267
+ """
1268
+ Args:
1269
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1270
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1271
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1272
+ output_attentions (`bool`, *optional*):
1273
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1274
+ returned tensors for more detail.
1275
+ use_cache (`bool`, *optional*):
1276
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1277
+ (see `past_key_values`).
1278
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1279
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1280
+ Indices depicting the position of the input sequence tokens in the sequence.
1281
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
1282
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
1283
+ with `head_dim` being the embedding dimension of each attention head.
1284
+ kwargs (`dict`, *optional*):
1285
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
1286
+ into the model
1287
+ """
1288
+
1289
+ residual = hidden_states
1290
+
1291
+ hidden_states = self.input_layernorm(hidden_states)
1292
+
1293
+ # Self Attention
1294
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1295
+ hidden_states=hidden_states,
1296
+ attention_mask=attention_mask,
1297
+ position_ids=position_ids,
1298
+ past_key_value=past_key_value,
1299
+ output_attentions=output_attentions,
1300
+ use_cache=use_cache,
1301
+ cache_position=cache_position,
1302
+ position_embeddings=position_embeddings,
1303
+ )
1304
+ hidden_states = residual + hidden_states
1305
+
1306
+ # Fully Connected
1307
+ residual = hidden_states
1308
+ hidden_states = self.post_attention_layernorm(hidden_states)
1309
+ hidden_states = self.mlp(hidden_states)
1310
+ hidden_states = residual + hidden_states
1311
+
1312
+ outputs = (hidden_states,)
1313
+
1314
+ if output_attentions:
1315
+ outputs += (self_attn_weights,)
1316
+
1317
+ if use_cache:
1318
+ outputs += (present_key_value,)
1319
+
1320
+ return outputs
1321
+
1322
+
1323
+ class Qwen2HybridDecoderLayer(nn.Module):
1324
+ def __init__(self,
1325
+ config,
1326
+ layer_idx: int,
1327
+ is_hyper_enabled=False,
1328
+ cross_attn_implementation="vanilla", # one of ['vanilla', 'text-only-vanilla']
1329
+ cross_attn_gating_type="channel-wise-dynamic-sigmoid"):
1330
+ super().__init__()
1331
+ self.is_hyper_enabled = is_hyper_enabled
1332
+
1333
+ self.hidden_size = config.hidden_size
1334
+
1335
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
1336
+ # logger.warning_once(
1337
+ # f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
1338
+ # "unexpected results may be encountered."
1339
+ # )
1340
+ pass
1341
+
1342
+ self.self_attn = QWEN2_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config=config,
1343
+ layer_idx=layer_idx,
1344
+ is_hyper_enabled=is_hyper_enabled,
1345
+ cross_attn_implementation=cross_attn_implementation,
1346
+ gating_type=cross_attn_gating_type)
1347
+
1348
+
1349
+ self.mlp = Qwen2MLP(config)
1350
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1351
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1352
+
1353
+ self.gradient_checkpointing = False # gradient checkpointing is applied inside forward() to the attention and MLP blocks separately
1354
+
1355
+ # Conditioning mechanism adapted from the flamingo-mini implementation (https://github.com/dhansmair/flamingo-mini/)
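+ # The surrounding model is expected to call `condition_vis_x` on each hybrid layer to cache the slow
+ # visual tokens, the text-to-vision attention mask, and the token-type map before the decoder forward
+ # pass, and `clear_vis_x` afterwards so that no stale visual state leaks between calls.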
1356
+ def condition_vis_x(self,
1357
+ vis_x,
1358
+ cross_attn_mask=None,
1359
+ token_type=None):
1360
+
1361
+ self.vis_x = vis_x
1362
+ self.cross_attn_mask = cross_attn_mask
1363
+ self.media_locations = token_type
1364
+
1365
+ def clear_vis_x(self):
1366
+ self.vis_x = None
1367
+ self.cross_attn_mask = None
1368
+ self.media_locations = None
1369
+
1370
+ def mlp_forward(self, hidden_states):
1371
+ hidden_states = self.post_attention_layernorm(hidden_states)
1372
+ hidden_states = self.mlp(hidden_states)
1373
+ return hidden_states
1374
+
1375
+ def forward(
1376
+ self,
1377
+ hidden_states: torch.Tensor,
1378
+ attention_mask: Optional[torch.Tensor] = None,
1379
+ position_ids: Optional[torch.LongTensor] = None,
1380
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1381
+ output_attentions: Optional[bool] = False,
1382
+ use_cache: Optional[bool] = False,
1383
+ cache_position: Optional[torch.LongTensor] = None,
1384
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
1385
+ **kwargs,
1386
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1387
+ """
1388
+ Args:
1389
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
1390
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1391
+ `(batch, sequence_length)` where padding elements are indicated by 0.
1392
+ output_attentions (`bool`, *optional*):
1393
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1394
+ returned tensors for more detail.
1395
+ use_cache (`bool`, *optional*):
1396
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1397
+ (see `past_key_values`).
1398
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1399
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
1400
+ Indices depicting the position of the input sequence tokens in the sequence.
1401
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
1402
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
1403
+ with `head_dim` being the embedding dimension of each attention head.
1404
+ kwargs (`dict`, *optional*):
1405
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
1406
+ into the model
1407
+ """
1408
+
1409
+ residual = hidden_states
1410
+
1411
+ hidden_states = self.input_layernorm(hidden_states)
1412
+
1413
+ # process image embedding
1414
+ visual_tokens = self.vis_x
1415
+ cross_attn_mask = self.cross_attn_mask
1416
+ token_type = self.media_locations
1417
+ visual_tokens = self.input_layernorm(visual_tokens)
1418
+
1419
+ # Self Attention
1420
+ if self.gradient_checkpointing and self.training:
1421
+ hidden_states, self_attn_weights, present_key_value = torch.utils.checkpoint.checkpoint(
1422
+ self.self_attn,
1423
+ hidden_states,
1424
+ visual_tokens,
1425
+ token_type,
1426
+ attention_mask,
1427
+ cross_attn_mask,
1428
+ position_ids,
1429
+ past_key_value,
1430
+ output_attentions,
1431
+ use_cache,
1432
+ cache_position,
1433
+ position_embeddings
1434
+ )
1435
+ else:
1436
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
1437
+ hidden_states=hidden_states,
1438
+ attention_mask=attention_mask,
1439
+ visual_hidden_states=visual_tokens,
1440
+ text2visual_attention_mask=cross_attn_mask,
1441
+ token_type=token_type,
1442
+ position_ids=position_ids,
1443
+ past_key_value=past_key_value,
1444
+ output_attentions=output_attentions,
1445
+ use_cache=use_cache,
1446
+ cache_position=cache_position,
1447
+ position_embeddings=position_embeddings,
1448
+ )
1449
+
1450
+ hidden_states = residual + hidden_states
1451
+
1452
+ # Fully Connected
1453
+ residual = hidden_states
1454
+ if self.gradient_checkpointing and self.training:
1455
+ hidden_states = torch.utils.checkpoint.checkpoint(
1456
+ self.mlp_forward,
1457
+ hidden_states)
1458
+ else:
1459
+ hidden_states = self.mlp_forward(hidden_states)
1460
+
1461
+ hidden_states = residual + hidden_states
1462
+
1463
+ outputs = (hidden_states,)
1464
+
1465
+ if output_attentions:
1466
+ outputs += (self_attn_weights,)
1467
+
1468
+ if use_cache:
1469
+ outputs += (present_key_value,)
1470
+
1471
+ return outputs
1472
+
1473
+
llava/model/language_model/llava_qwen2.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+
21
+ import transformers
22
+ from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
23
+
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+
27
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
29
+
30
+
31
+ class LlavaQwenConfig(Qwen2Config):
32
+ model_type = "llava_qwen"
33
+
34
+
35
+ class LlavaQwenModel(LlavaMetaModel, Qwen2Model):
36
+ config_class = LlavaQwenConfig
37
+
38
+ def __init__(self, config: Qwen2Config):
39
+ super(LlavaQwenModel, self).__init__(config)
40
+
41
+
42
+ class LlavaQwenForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
43
+ config_class = LlavaQwenConfig
44
+
45
+ def __init__(self, config):
46
+ Qwen2ForCausalLM.__init__(self, config)
47
+ config.model_type = "llava_qwen"
48
+ config.rope_scaling = None
49
+
50
+ self.model = LlavaQwenModel(config)
51
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
52
+ self.post_init()
53
+
54
+ def get_model(self):
55
+ return self.model
56
+
57
+ def forward(
58
+ self,
59
+ input_ids: torch.LongTensor = None,
60
+ attention_mask: Optional[torch.Tensor] = None,
61
+ position_ids: Optional[torch.LongTensor] = None,
62
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
63
+ inputs_embeds: Optional[torch.FloatTensor] = None,
64
+ labels: Optional[torch.LongTensor] = None,
65
+ use_cache: Optional[bool] = None,
66
+ output_attentions: Optional[bool] = None,
67
+ output_hidden_states: Optional[bool] = None,
68
+ images: Optional[torch.FloatTensor] = None,
69
+ image_sizes: Optional[List[List[int]]] = None,
70
+ return_dict: Optional[bool] = None,
71
+ modalities: Optional[List[str]] = ["image"],
72
+ dpo_forward: Optional[bool] = False,
73
+ cache_position=None,
74
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
75
+
76
+ if inputs_embeds is None:
77
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
78
+
79
+ if dpo_forward:
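+ # For DPO-style training the model returns the raw logits together with the (multimodal-expanded)
+ # labels instead of a loss, presumably so the caller can compute the preference loss externally.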
80
+ outputs = self.model(
81
+ input_ids=input_ids,
82
+ attention_mask=attention_mask,
83
+ position_ids=position_ids,
84
+ past_key_values=past_key_values,
85
+ inputs_embeds=inputs_embeds,
86
+ use_cache=use_cache,
87
+ output_attentions=output_attentions,
88
+ output_hidden_states=output_hidden_states,
89
+ return_dict=return_dict,
90
+ )
91
+
92
+ hidden_states = outputs[0]
93
+ logits = self.lm_head(hidden_states)
94
+ return logits, labels
95
+
96
+ else:
97
+ return super().forward(
98
+ input_ids=input_ids,
99
+ attention_mask=attention_mask,
100
+ position_ids=position_ids,
101
+ past_key_values=past_key_values,
102
+ inputs_embeds=inputs_embeds,
103
+ labels=labels,
104
+ use_cache=use_cache,
105
+ output_attentions=output_attentions,
106
+ output_hidden_states=output_hidden_states,
107
+ return_dict=return_dict,
108
+ )
109
+
110
+ @torch.no_grad()
111
+ def generate(
112
+ self,
113
+ inputs: Optional[torch.Tensor] = None,
114
+ images: Optional[torch.Tensor] = None,
115
+ image_sizes: Optional[torch.Tensor] = None,
116
+ modalities: Optional[List[str]] = ["image"],
117
+ **kwargs,
118
+ ) -> Union[GenerateOutput, torch.LongTensor]:
119
+ position_ids = kwargs.pop("position_ids", None)
120
+ attention_mask = kwargs.pop("attention_mask", None)
121
+ if "inputs_embeds" in kwargs:
122
+ raise NotImplementedError("`inputs_embeds` is not supported")
123
+
124
+ if images is not None:
125
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
126
+ else:
127
+ inputs_embeds = self.get_model().embed_tokens(inputs)
128
+
129
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
130
+
131
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
132
+ images = kwargs.pop("images", None)
133
+ image_sizes = kwargs.pop("image_sizes", None)
134
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
135
+ if images is not None:
136
+ inputs["images"] = images
137
+ if image_sizes is not None:
138
+ inputs["image_sizes"] = image_sizes
139
+ return inputs
140
+
141
+
142
+ AutoConfig.register("llava_qwen", LlavaQwenConfig)
143
+ AutoModelForCausalLM.register(LlavaQwenConfig, LlavaQwenForCausalLM)
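+
+ # Illustrative usage sketch (not part of the original file): assumes a trained checkpoint and inputs
+ # already preprocessed in the LLaVA style, with image placeholders expanded to IMAGE_TOKEN_INDEX;
+ # the path, tensors, and sizes below are placeholders.
+ #
+ # model = LlavaQwenForCausalLM.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16).cuda()
+ # output_ids = model.generate(inputs=input_ids, images=image_tensor, image_sizes=[[height, width]], max_new_tokens=128)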
llava/model/language_model/llava_qwen2_slowfast.py ADDED
@@ -0,0 +1,632 @@
1
+ # Copyright 2024 Hao Zhang
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import List, Optional, Tuple, Union, Dict
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import CrossEntropyLoss
20
+ from collections import OrderedDict
21
+
22
+ import transformers
23
+ # from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
24
+ from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast, CausalLMOutputWithPast
25
+ from transformers.generation.utils import GenerateOutput
26
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
27
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
28
+
29
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
30
+ from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
31
+ from llava.model.language_model.hybrid_decoder_layer import Qwen2DecoderLayer, Qwen2HybridDecoderLayer
32
+
33
+ class LlavaQwenSlowFastConfig(Qwen2Config):
34
+ model_type = "llava_qwen_slow_fast"
35
+
36
+
37
+ class LlavaQwenSlowFastModel(LlavaMetaModel, Qwen2Model):
38
+ config_class = LlavaQwenSlowFastConfig
39
+
40
+ def __init__(self, config: Qwen2Config):
41
+ super(LlavaQwenSlowFastModel, self).__init__(config)
42
+
43
+ # initialize the cross-attention layers
44
+ self.slow_branch_is_initialized = False
45
+
46
+ if hasattr(config, "cross_attn_every_n_layers"):
47
+ self.initialize_slow_branch_modules(config)
48
+
49
+ def initialize_slow_branch_modules(self, args):
50
+ if self.slow_branch_is_initialized:
51
+ return
52
+ # number of decoder layers
53
+ num_layers = len(self.layers)
54
+
55
+ cross_attn_every_n_layers = args.cross_attn_every_n_layers
56
+ cross_attn_gating_type = args.cross_attn_gating_type
57
+ cross_attn_implementation = args.cross_attn_implementation
58
+ cross_attn_max_layer_depth = getattr(args, "cross_attn_max_layer_depth", num_layers)
59
+ cross_attn_min_layer_depth = getattr(args, "cross_attn_min_layer_depth", 0)
60
+ if cross_attn_max_layer_depth is None:
61
+ cross_attn_max_layer_depth = num_layers
62
+ if cross_attn_min_layer_depth is None:
63
+ cross_attn_min_layer_depth = 0
64
+
65
+ self.config.cross_attn_every_n_layers = cross_attn_every_n_layers
66
+ self.config.cross_attn_implementation = cross_attn_implementation
67
+ self.config.cross_attn_gating_type = cross_attn_gating_type
68
+ self.config.cross_attn_max_layer_depth = cross_attn_max_layer_depth
69
+ self.config.cross_attn_min_layer_depth = cross_attn_min_layer_depth
70
+
71
+ # set pooling operations
72
+ tile_image_input = getattr(args, "tile_image_input", True) # tile all image inputs into a single video-like sequence
73
+ min_fast_frames = getattr(args, "min_fast_frames", 1) # force to sample at least `min_fast_frames` frames for fast visual tokens
74
+ if min_fast_frames is None:
75
+ min_fast_frames = 1
76
+
77
+ self.config.tile_image_input = tile_image_input
78
+ self.config.min_fast_frames = min_fast_frames
79
+
80
+ # generate layer index for the hybrid layer
81
+ hybrid_layer_idx = []
82
+ for i in range(cross_attn_min_layer_depth, cross_attn_max_layer_depth, cross_attn_every_n_layers):
83
+ hybrid_layer_idx.append(i)
84
+
85
+ # substitute the original decoder layer with hybrid layer
86
+ initialize_kv_from_lm = getattr(args, "initialize_cross_attn_kv_from_lm", False) # whether to use the LLM's pretrained k/v projections to initialize the cross-attention k/v projection weights
87
+ for idx in range(len(self.layers)):
88
+ if idx in hybrid_layer_idx:
89
+ original_decoder_layer = self.layers[idx]
90
+ hybrid_decoder_layer = Qwen2HybridDecoderLayer(self.config, layer_idx=idx, is_hyper_enabled=True, cross_attn_gating_type=cross_attn_gating_type, cross_attn_implementation=cross_attn_implementation)
91
+ _, unexpected_keys = hybrid_decoder_layer.load_state_dict(original_decoder_layer.state_dict(), strict=False) # may cause problems when using DeepSpeed ZeRO-3
92
+ if initialize_kv_from_lm and hasattr(hybrid_decoder_layer.self_attn, "cross_attn_kv_proj"):
93
+ kv_weight = torch.cat([original_decoder_layer.self_attn.k_proj.weight,
94
+ original_decoder_layer.self_attn.v_proj.weight], dim=0)
95
+ kv_bias = torch.cat([original_decoder_layer.self_attn.k_proj.bias,
96
+ original_decoder_layer.self_attn.v_proj.bias], dim=0)
97
+ new_state_dict = OrderedDict()
98
+ new_state_dict['weight'] = kv_weight
99
+ new_state_dict['bias'] = kv_bias
100
+ hybrid_decoder_layer.self_attn.cross_attn_kv_proj.load_state_dict(new_state_dict)
101
+ assert len(unexpected_keys) == 0
102
+ self.layers[idx] = hybrid_decoder_layer
103
+
104
+ # fast token config
105
+ self.config.fast_token_spatial_stride = args.fast_token_spatial_stride
106
+ self.config.fast_token_temporal_stride = args.fast_token_temporal_stride
107
+ self.config.fast_token_temporal_sampling_stride = args.fast_token_temporal_sampling_stride
108
+
109
+ self.slow_branch_is_initialized = True
110
+
111
+ def forward(
112
+ self,
113
+ input_ids: torch.LongTensor = None,
114
+ attention_mask: Optional[torch.Tensor] = None,
115
+ position_ids: Optional[torch.LongTensor] = None,
116
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
117
+ inputs_embeds: Optional[torch.FloatTensor] = None,
118
+ use_cache: Optional[bool] = None,
119
+ output_attentions: Optional[bool] = None,
120
+ output_hidden_states: Optional[bool] = None,
121
+ return_dict: Optional[bool] = None,
122
+ cache_position: Optional[torch.LongTensor] = None,
123
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
124
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
125
+ output_hidden_states = (
126
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
127
+ )
128
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
129
+
130
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
131
+
132
+ if (input_ids is None) ^ (inputs_embeds is not None):
133
+ raise ValueError(
134
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
135
+ )
136
+
137
+ if self.gradient_checkpointing and self.training:
138
+ if use_cache:
139
+ # logger.warning_once(
140
+ # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
141
+ # )
142
+ use_cache = False
143
+
144
+ # kept for BC (non `Cache` `past_key_values` inputs)
145
+ return_legacy_cache = False
146
+ if use_cache and not isinstance(past_key_values, Cache):
147
+ return_legacy_cache = True
148
+ if past_key_values is None:
149
+ past_key_values = DynamicCache()
150
+ else:
151
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
152
+ # logger.warning_once(
153
+ # "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
154
+ # "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
155
+ # "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
156
+ # )
157
+
158
+ if inputs_embeds is None:
159
+ inputs_embeds = self.embed_tokens(input_ids)
160
+
161
+ if cache_position is None:
162
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
163
+ cache_position = torch.arange(
164
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
165
+ )
166
+ if position_ids is None:
167
+ position_ids = cache_position.unsqueeze(0)
168
+
169
+ causal_mask = self._update_causal_mask(
170
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
171
+ )
172
+
173
+ hidden_states = inputs_embeds
174
+
175
+ # create position embeddings to be shared across the decoder layers
176
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
177
+
178
+ # decoder layers
179
+ all_hidden_states = () if output_hidden_states else None
180
+ all_self_attns = () if output_attentions else None
181
+ next_decoder_cache = None
182
+
183
+ for decoder_layer in self.layers:
184
+ if output_hidden_states:
185
+ all_hidden_states += (hidden_states,)
186
+
187
+ if self.gradient_checkpointing and self.training:
188
+ if not isinstance(decoder_layer, Qwen2HybridDecoderLayer):
189
+ layer_outputs = self._gradient_checkpointing_func(
190
+ decoder_layer.__call__,
191
+ hidden_states,
192
+ causal_mask,
193
+ position_ids,
194
+ past_key_values,
195
+ output_attentions,
196
+ use_cache,
197
+ cache_position,
198
+ position_embeddings,
199
+ )
200
+ else:
201
+ layer_outputs = decoder_layer(
202
+ hidden_states,
203
+ attention_mask=causal_mask,
204
+ position_ids=position_ids,
205
+ past_key_value=past_key_values,
206
+ output_attentions=output_attentions,
207
+ use_cache=use_cache,
208
+ cache_position=cache_position,
209
+ position_embeddings=position_embeddings,
210
+ )
211
+
212
+ hidden_states = layer_outputs[0]
213
+
214
+ if use_cache:
215
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
216
+
217
+ if output_attentions:
218
+ all_self_attns += (layer_outputs[1],)
219
+
220
+ hidden_states = self.norm(hidden_states)
221
+
222
+ # add hidden states from the last decoder layer
223
+ if output_hidden_states:
224
+ all_hidden_states += (hidden_states,)
225
+
226
+ next_cache = next_decoder_cache if use_cache else None
227
+ if return_legacy_cache:
228
+ next_cache = next_cache.to_legacy_cache()
229
+
230
+ if not return_dict:
231
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
232
+ return BaseModelOutputWithPast(
233
+ last_hidden_state=hidden_states,
234
+ past_key_values=next_cache,
235
+ hidden_states=all_hidden_states,
236
+ attentions=all_self_attns,
237
+ )
238
+
239
+
240
+
241
+ class LlavaQwenSlowFastForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
242
+ config_class = LlavaQwenSlowFastConfig
243
+
244
+ def __init__(self, config):
245
+ Qwen2ForCausalLM.__init__(self, config)
246
+ config.model_type = "llava_qwen_slow_fast"
247
+ config.rope_scaling = None
248
+
249
+ self.model = LlavaQwenSlowFastModel(config)
250
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
251
+ self.post_init()
252
+
253
+ def get_model(self):
254
+ return self.model
255
+
256
+ def _set_gradient_checkpointing(self, module, value=False):
257
+ if isinstance(module, Qwen2HybridDecoderLayer):
258
+ module.gradient_checkpointing = value
259
+
260
+ def forward(
261
+ self,
262
+ input_ids: torch.LongTensor = None,
263
+ attention_mask: Optional[torch.Tensor] = None,
264
+ position_ids: Optional[torch.LongTensor] = None,
265
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
266
+ inputs_embeds: Optional[torch.FloatTensor] = None,
267
+ labels: Optional[torch.LongTensor] = None,
268
+ use_cache: Optional[bool] = None,
269
+ output_attentions: Optional[bool] = None,
270
+ output_hidden_states: Optional[bool] = None,
271
+ images: Optional[torch.FloatTensor] = None,
272
+ image_sizes: Optional[List[List[int]]] = None,
273
+ return_dict: Optional[bool] = None,
274
+ modalities: Optional[List[str]] = ["image"],
275
+ cache_position=None,
276
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
277
+
278
+ if inputs_embeds is None:
279
+ (input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels) = self.prepare_inputs_labels_for_multimodal(input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
280
+
281
+ return super().forward(
282
+ input_ids=input_ids,
283
+ attention_mask=attention_mask,
284
+ position_ids=position_ids,
285
+ past_key_values=past_key_values,
286
+ inputs_embeds=inputs_embeds,
287
+ labels=labels,
288
+ use_cache=use_cache,
289
+ output_attentions=output_attentions,
290
+ output_hidden_states=output_hidden_states,
291
+ return_dict=return_dict,
292
+ )
293
+
294
+ @torch.no_grad()
295
+ def generate(
296
+ self,
297
+ inputs: Optional[torch.Tensor] = None,
298
+ images: Optional[torch.Tensor] = None,
299
+ image_sizes: Optional[torch.Tensor] = None,
300
+ modalities: Optional[List[str]] = ["image"],
301
+ **kwargs,
302
+ ) -> Union[GenerateOutput, torch.LongTensor]:
303
+ position_ids = kwargs.pop("position_ids", None)
304
+ attention_mask = kwargs.pop("attention_mask", None)
305
+ if "inputs_embeds" in kwargs:
306
+ raise NotImplementedError("`inputs_embeds` is not supported")
307
+
308
+ if images is not None:
309
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes)
310
+ else:
311
+ inputs_embeds = self.get_model().embed_tokens(inputs)
312
+
313
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
314
+
315
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
316
+ images = kwargs.pop("images", None)
317
+ image_sizes = kwargs.pop("image_sizes", None)
318
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
319
+ if images is not None:
320
+ inputs["images"] = images
321
+ if image_sizes is not None:
322
+ inputs["image_sizes"] = image_sizes
323
+ return inputs
324
+
325
+ def sample_fast_frames(self,
326
+ total_frames,
327
+ stride,
328
+ min_frame_number):
329
+
330
+ all_indices_list = list(range(total_frames))
331
+
332
+ if total_frames < min_frame_number:
333
+ return all_indices_list
334
+
335
+ sampled_frames = max(total_frames // stride, min_frame_number)
336
+ stride = total_frames / sampled_frames
337
+
338
+ fast_indices = [min(int(i * stride), total_frames-1) for i in range(sampled_frames)]
339
+
340
+ return fast_indices
341
+
342
+ def split_slow_fast_tokens(self,
343
+ visual_tokens,
344
+ temporal_sampling_stride=1,
345
+ spatial_stride=1,
346
+ temporal_stride=1):
347
+ # TODO: Min: this function is very messy and can be simplified.
348
+ if isinstance(visual_tokens, torch.Tensor):
349
+ # for all image inputs, only perform spatial pooling
350
+ b, n, c = visual_tokens.shape
351
+ h = w = int(n**0.5)
352
+ fast_visual_tokens = nn.functional.avg_pool2d(visual_tokens.reshape(b, h, w, c).permute(0, 3, 1, 2),
353
+ kernel_size=spatial_stride,
354
+ stride=spatial_stride).flatten(2,3).transpose(1,2)
355
+ return fast_visual_tokens, visual_tokens
356
+ else:
357
+ fast_visual_tokens = []
358
+ for sample_ in visual_tokens:
359
+ t, n, c = sample_.shape
360
+ if t > 1: # is a video
361
+ T_downsampling_rate = temporal_sampling_stride * temporal_stride
362
+
363
+ if t % T_downsampling_rate != 0:
364
+ padding_size = (T_downsampling_rate - t % T_downsampling_rate) % T_downsampling_rate
365
+ # Pad on the first dimension (sequence length) with zeros
366
+ sample_ = nn.functional.pad(sample_, (0, 0, 0, 0, 0, padding_size)) # (dim_pad_left, dim_pad_right, T_pad_left, T_pad_right)
367
+
368
+ # 1. temporal direct sampling
369
+ if temporal_sampling_stride > 1:
370
+ fast_token_indices = self.sample_fast_frames(total_frames=t,
371
+ stride=temporal_sampling_stride,
372
+ min_frame_number=self.config.min_fast_frames)
373
+ else:
374
+ fast_token_indices = list(range(t))
375
+
376
+ sample_ = torch.stack([sample_[idx] for idx in fast_token_indices], dim=0)
377
+ b, n, c = sample_.shape
378
+ h = w = int(n**0.5)
379
+ sample_ = sample_.reshape(b, h, w, c).permute(0, 3, 1, 2)
380
+
381
+ # 2. temporal average pooling
382
+ if temporal_stride > 1:
383
+ if (sample_.shape[0] // temporal_stride) >= self.config.min_fast_frames:
384
+ sample_ = nn.functional.avg_pool3d(sample_.transpose(0, 1), kernel_size=(temporal_stride, 1, 1)).transpose(0, 1)
385
+ else:
386
+ h_, w_ = sample_.shape[-2:]
387
+ output_frames_num = min(sample_.shape[0], self.config.min_fast_frames)
388
+ sample_ = nn.functional.adaptive_avg_pool3d(sample_.transpose(0, 1), output_size=(output_frames_num, h_, w_)).transpose(0, 1)
389
+
390
+ # 3. spatial pooling
391
+ if spatial_stride > 1:
392
+ sample_ = nn.functional.avg_pool2d(sample_,
393
+ kernel_size=spatial_stride,
394
+ stride=spatial_stride)
395
+ sample_ = sample_.flatten(2,3).transpose(1,2)
396
+
397
+ else:
398
+ if spatial_stride > 1:
399
+ h = w = int(n**0.5)
400
+ sample_ = sample_.reshape(t, h, w, c).permute(0, 3, 1, 2)
401
+ sample_ = nn.functional.avg_pool2d(sample_,
402
+ kernel_size=spatial_stride,
403
+ stride=spatial_stride)
404
+ sample_ = sample_.flatten(2,3).transpose(1,2)
405
+
406
+ fast_visual_tokens.append(sample_.flatten(0, 1).contiguous())
407
+ slow_visual_tokens = [_.flatten(0, 1).contiguous() for _ in visual_tokens]
408
+
409
+ return fast_visual_tokens, slow_visual_tokens
410
+
411
+
412
+ def prepare_inputs_labels_for_multimodal(
413
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
414
+ images, image_sizes=None
415
+ ):
416
+ vision_tower = self.get_vision_tower()
417
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
418
+ # clear the visual tokens if current one is a pure text sample
419
+ if images is None and input_ids.shape[1] > 1:
420
+ for layer in self.get_decoder().layers:
421
+ if hasattr(layer, "clear_vis_x"):
422
+ layer.clear_vis_x()
423
+
424
+ token_types = torch.ones_like(input_ids, dtype=input_ids.dtype, device=input_ids.device)
425
+ for layer in self.get_decoder().layers:
426
+ if hasattr(layer, "condition_vis_x"):
427
+ layer.media_locations = token_types
428
+
429
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
430
+
431
+ # handle image input
432
+ images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # list [ [T, C, H, W], ]
433
+ feature_split_size = [x.shape[0] for x in images]
434
+ all_features, feature_split_size = self.encode_images(torch.cat(images, dim=0), feature_split_size)
435
+
436
+ raw_image_features = torch.split(all_features, feature_split_size, dim=0)
437
+ image_features = []
438
+ for sample_feat in raw_image_features: # initial spatial pooling for all video tokens
439
+ if sample_feat.shape[0] > 1 and self.config.mm_video_pooling_stride > 1:
440
+ b, n, c = sample_feat.shape
441
+ h = w = int(n**0.5)
442
+ sample_feat = nn.functional.avg_pool2d(sample_feat.reshape(b, h, w, c).permute(0, 3, 1, 2),
443
+ kernel_size=self.config.mm_video_pooling_stride,
444
+ stride=self.config.mm_video_pooling_stride).flatten(2,3).transpose(1,2)
445
+ image_features.append(sample_feat.contiguous())
446
+ del raw_image_features, all_features
447
+
448
+ ## generate fast and slow tokens
449
+ image_features, slow_image_features = self.split_slow_fast_tokens(image_features,
450
+ temporal_sampling_stride=self.config.fast_token_temporal_sampling_stride,
451
+ spatial_stride=self.config.fast_token_spatial_stride,
452
+ temporal_stride=self.config.fast_token_temporal_stride)
453
+
454
+ ## set cross-attention states
455
+ if isinstance(slow_image_features, (list, tuple)):
456
+ padded_tensors = torch.nn.utils.rnn.pad_sequence(slow_image_features, batch_first=True)
457
+ cross_attn_mask = torch.ones(padded_tensors.shape[:-1], dtype=torch.bool, device=padded_tensors.device)
458
+ for i, tensor in enumerate(slow_image_features):
459
+ cross_attn_mask[i, len(tensor):] = False # Mark padded elements as False
460
+ slow_image_features = padded_tensors
461
+ else:
462
+ cross_attn_mask = torch.ones(slow_image_features.shape[:-1], dtype=torch.bool, device=slow_image_features.device)
463
+
464
+ # TODO: image start / end is not implemented here to support pretraining.
465
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
466
+ raise NotImplementedError
467
+
468
+ # Let's just add dummy tensors if they do not exist,
469
+ # it is a headache to deal with None all the time.
470
+ # But it is not ideal, and if you have a better idea,
471
+ # please open an issue / submit a PR, thanks.
472
+ _labels = labels
473
+ _position_ids = position_ids
474
+ _attention_mask = attention_mask
475
+ if attention_mask is None:
476
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
477
+ else:
478
+ attention_mask = attention_mask.bool()
479
+ if position_ids is None:
480
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
481
+ if labels is None:
482
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
483
+
484
+ # remove the padding using attention_mask -- FIXME
485
+ _input_ids = input_ids
486
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
487
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
488
+
489
+ new_input_embeds = []
490
+ new_labels = []
491
+ cur_image_idx = 0
492
+ new_token_types = []
493
+ # NOTE: Min: we need to record the type of tokens so that we can split the tokens in the hybrid decoder layer
494
+ # Token type 1: user's input and system tokens, 2: response text tokens, 3: visual tokens, 4: invalid tokens (padding)
495
+
496
+ for batch_idx, cur_input_ids in enumerate(input_ids):
497
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
498
+ if num_images == 0:
499
+ cur_image_features = image_features[cur_image_idx]
500
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
501
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
502
+ new_input_embeds.append(cur_input_embeds)
503
+ new_labels.append(labels[batch_idx])
504
+
505
+ cur_token_type = torch.full((cur_input_ids.shape[0],), 2, dtype=cur_input_ids[-1].dtype, device=cur_input_ids[-1].device)
506
+ cur_token_type[labels[batch_idx] == IGNORE_INDEX] = 1 # tokens with IGNORE_INDEX labels are treated as user input
507
+ new_token_types.append(cur_token_type)
508
+
509
+ cur_image_idx += 1
510
+ continue
511
+
512
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
513
+ cur_input_ids_noim = []
514
+ cur_labels = labels[batch_idx]
515
+ cur_labels_noim = []
516
+ cur_token_type_noim = []
517
+
518
+ for i in range(len(image_token_indices) - 1):
519
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
520
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
521
+
522
+ cur_token = torch.full((cur_labels_noim[-1].shape[0],), 2, dtype=cur_input_ids_noim[-1].dtype, device=cur_input_ids_noim[-1].device)
523
+ cur_token[cur_labels[image_token_indices[i]+1:image_token_indices[i+1]] == IGNORE_INDEX] = 1 # ignored tokens are treated as user input
524
+ cur_token_type_noim.append(cur_token)
525
+
526
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
527
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
528
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
529
+ cur_new_input_embeds = []
530
+ cur_new_labels = []
531
+ cur_new_token_type = []
532
+
533
+ for i in range(num_images + 1):
534
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
535
+ cur_new_labels.append(cur_labels_noim[i])
536
+ cur_new_token_type.append(cur_token_type_noim[i])
537
+
538
+ if i < num_images:
539
+ cur_image_features = image_features[cur_image_idx]
540
+ cur_image_idx += 1
541
+ cur_new_input_embeds.append(cur_image_features)
542
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
543
+ cur_new_token_type.append(torch.full((cur_image_features.shape[0],), 3, device=cur_labels.device, dtype=cur_labels.dtype)) # insert image token type
544
+
545
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
546
+
547
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
548
+ cur_new_labels = torch.cat(cur_new_labels)
549
+ cur_new_token_type = torch.cat(cur_new_token_type) ##
550
+
551
+ new_input_embeds.append(cur_new_input_embeds)
552
+ new_labels.append(cur_new_labels)
553
+ new_token_types.append(cur_new_token_type) ##
554
+
555
+ # Truncate sequences to max length as image embeddings can make the sequence longer
556
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
557
+ if tokenizer_model_max_length is not None:
558
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
559
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
560
+ new_token_types = [x[:tokenizer_model_max_length] for x in new_token_types]
561
+
562
+ # Combine them
563
+ max_len = max(x.shape[0] for x in new_input_embeds)
564
+ batch_size = len(new_input_embeds)
565
+
566
+ new_input_embeds_padded = []
567
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
568
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
569
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
570
+ new_token_types_padded = torch.full((batch_size, max_len), 4, dtype=new_labels[0].dtype, device=new_labels[0].device) ## 4 is invalid token type (padding)
571
+
572
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
573
+ cur_len = cur_new_embed.shape[0]
574
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
575
+ new_input_embeds_padded.append(torch.cat((
576
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
577
+ cur_new_embed
578
+ ), dim=0))
579
+ if cur_len > 0:
580
+ new_labels_padded[i, -cur_len:] = cur_new_labels
581
+ attention_mask[i, -cur_len:] = True
582
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
583
+ new_token_types_padded[i, -cur_len:] = new_token_types[i] ##
584
+ else:
585
+ new_input_embeds_padded.append(torch.cat((
586
+ cur_new_embed,
587
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
588
+ ), dim=0))
589
+ if cur_len > 0:
590
+ new_labels_padded[i, :cur_len] = cur_new_labels
591
+ attention_mask[i, :cur_len] = True
592
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
593
+ new_token_types_padded[i, :cur_len] = new_token_types[i]
594
+
595
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
596
+
597
+ if _labels is None:
598
+ new_labels = None
599
+ else:
600
+ new_labels = new_labels_padded
601
+
602
+ if _attention_mask is None:
603
+ attention_mask = None
604
+ else:
605
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
606
+
607
+ if _position_ids is None:
608
+ position_ids = None
609
+
610
+ # token type
611
+ token_types = new_token_types_padded
612
+ # send token type to cross-attn layers
613
+ if _input_ids is not None and _input_ids.shape[-1] == 1:
614
+ pass
615
+ else:
616
+ if slow_image_features is not None:
617
+ for layer in self.get_decoder().layers:
618
+ if hasattr(layer, "condition_vis_x"):
619
+ layer.condition_vis_x(slow_image_features,
620
+ cross_attn_mask,
621
+ token_type=token_types)
622
+ else:
623
+ for layer in self.get_decoder().layers:
624
+ if hasattr(layer, "clear_vis_x"):
625
+ layer.clear_vis_x()
626
+
627
+
628
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
629
+
630
+
631
+ AutoConfig.register("llava_qwen_slow_fast", LlavaQwenSlowFastConfig)
632
+ AutoModelForCausalLM.register(LlavaQwenSlowFastConfig, LlavaQwenSlowFastForCausalLM)
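The slow/fast split implemented in split_slow_fast_tokens above gives the two visual streams very different budgets: the slow tokens keep every frame for cross-attention in the hybrid decoder layers, while the fast tokens are thinned by temporal sampling, temporal pooling, and spatial pooling before being spliced into the LLM input. A rough back-of-the-envelope sketch of the resulting sizes, with illustrative stride values and ignoring the padding / min_fast_frames edge cases handled in the code:

# Illustrative numbers only; the strides correspond to fast_token_temporal_sampling_stride,
# fast_token_temporal_stride and fast_token_spatial_stride in the config.
T, N = 64, 576                       # frames, visual tokens per frame after the projector
t_sample, t_pool, s_pool = 2, 2, 2   # temporal sampling / temporal pooling / spatial pooling strides
min_fast_frames = 1

fast_frames = max(max(T // t_sample, min_fast_frames) // t_pool, min_fast_frames)
fast_tokens = fast_frames * (N // (s_pool * s_pool))   # 16 * 144 = 2304 tokens in the LLM sequence
slow_tokens = T * N                                    # 64 * 576 = 36864 tokens kept for cross-attention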
llava/model/llava_arch.py ADDED
@@ -0,0 +1,355 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ # Copyright 2023 Haotian Liu
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from abc import ABC, abstractmethod
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+ from .multimodal_encoder.builder import build_vision_tower
24
+ from .multimodal_projector.builder import build_vision_projector
25
+
26
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
27
+
28
+ from llava.mm_utils import get_anyres_image_grid_shape
29
+
30
+
31
+ class LlavaMetaModel:
32
+
33
+ def __init__(self, config):
34
+ super(LlavaMetaModel, self).__init__(config)
35
+
36
+ if hasattr(config, "mm_vision_tower"):
37
+ self.vision_tower = build_vision_tower(config, delay_load=True)
38
+ fpn_input_dim = [] if not hasattr(self.vision_tower, "fpn_input_dim") else self.vision_tower.fpn_input_dim
39
+ self.mm_projector = build_vision_projector(config, fpn_input_dim=fpn_input_dim)
40
+
41
+ if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
42
+ self.image_newline = nn.Parameter(
43
+ torch.empty(config.hidden_size, dtype=self.dtype)
44
+ )
45
+
46
+ def get_vision_tower(self):
47
+ vision_tower = getattr(self, 'vision_tower', None)
48
+ if type(vision_tower) is list:
49
+ vision_tower = vision_tower[0]
50
+ return vision_tower
51
+
52
+ def initialize_vision_modules(self, model_args, fsdp=None):
53
+ vision_tower = model_args.vision_tower
54
+ mm_vision_select_layer = model_args.mm_vision_select_layer
55
+ mm_vision_select_feature = model_args.mm_vision_select_feature
56
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
57
+ mm_patch_merge_type = model_args.mm_patch_merge_type
58
+ mm_video_pooling_stride = model_args.mm_video_pooling_stride
59
+
60
+ self.config.mm_vision_tower = vision_tower
61
+
62
+ if self.get_vision_tower() is None:
63
+ vision_tower = build_vision_tower(model_args)
64
+
65
+ if fsdp is not None and len(fsdp) > 0:
66
+ self.vision_tower = [vision_tower]
67
+ else:
68
+ self.vision_tower = vision_tower
69
+ else:
70
+ if fsdp is not None and len(fsdp) > 0:
71
+ vision_tower = self.vision_tower[0]
72
+ else:
73
+ vision_tower = self.vision_tower
74
+ vision_tower.load_model()
75
+
76
+ self.config.use_mm_proj = True
77
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
78
+ self.config.mm_hidden_size = vision_tower.hidden_size
79
+ self.config.mm_vision_select_layer = mm_vision_select_layer
80
+ self.config.mm_vision_select_feature = mm_vision_select_feature
81
+ self.config.mm_patch_merge_type = mm_patch_merge_type
82
+ self.config.mm_video_pooling_stride = mm_video_pooling_stride
83
+
84
+ if getattr(self, 'mm_projector', None) is None:
85
+ fpn_input_dim = [] if not hasattr(self.vision_tower, "fpn_input_dim") else self.vision_tower.fpn_input_dim
86
+ self.mm_projector = build_vision_projector(self.config, fpn_input_dim=fpn_input_dim)
87
+
88
+ if 'unpad' in mm_patch_merge_type:
89
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
90
+ self.image_newline = nn.Parameter(
91
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
92
+ )
93
+ else:
94
+ # In case it is frozen by LoRA
95
+ for p in self.mm_projector.parameters():
96
+ p.requires_grad = True
97
+
98
+ if pretrain_mm_mlp_adapter is not None:
99
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
100
+ def get_w(weights, keyword):
101
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
102
+
103
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
104
+
105
+
106
+ def unpad_image(tensor, original_size):
107
+ """
108
+ Unpads a PyTorch tensor of a padded and resized image.
109
+
110
+ Args:
111
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
112
+ original_size (tuple): The original size of the image (height, width).
113
+
114
+ Returns:
115
+ torch.Tensor: The unpadded image tensor.
116
+ """
117
+ original_width, original_height = original_size
118
+ current_height, current_width = tensor.shape[1:]
119
+
120
+ original_aspect_ratio = original_width / original_height
121
+ current_aspect_ratio = current_width / current_height
122
+
123
+ if original_aspect_ratio > current_aspect_ratio:
124
+ scale_factor = current_width / original_width
125
+ new_height = int(original_height * scale_factor)
126
+ padding = (current_height - new_height) // 2
127
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
128
+ else:
129
+ scale_factor = current_height / original_height
130
+ new_width = int(original_width * scale_factor)
131
+ padding = (current_width - new_width) // 2
132
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
133
+
134
+ return unpadded_tensor
135
+
136
+
137
+ class LlavaMetaForCausalLM(ABC):
138
+
139
+ @abstractmethod
140
+ def get_model(self):
141
+ pass
142
+
143
+ def get_vision_tower(self):
144
+ return self.get_model().get_vision_tower()
145
+
146
+ def encode_images(self, images, feat_split_size=None):
147
+ image_features = self.get_model().get_vision_tower()(images)
148
+
149
+ if "st" in self.config.mm_projector_type:
150
+ # need temporal correlations
151
+ if feat_split_size is None:
152
+ feat_split_size = [1] * images.shape[0]
153
+ image_features = image_features.split(feat_split_size)
154
+ image_features = self.get_model().mm_projector(image_features)
155
+ feat_split_size = [_.shape[0] for _ in image_features]
156
+ image_features = torch.cat(image_features, dim=0)
157
+
158
+ else:
159
+ image_features = self.get_model().mm_projector(image_features)
160
+ return image_features, feat_split_size
161
+
162
+ def prepare_inputs_labels_for_multimodal(
163
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
164
+ images, image_sizes=None
165
+ ):
166
+ vision_tower = self.get_vision_tower()
167
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
168
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
169
+
170
+ if type(images) is list: # for each element, either video tensor [T, C, H, W], or image tensor [C, H, W]
171
+
172
+ images = [image if len(image.shape) == 4 else image.unsqueeze(0) for image in images] # list [ [T, C, H, W], ]
173
+ feature_split_size = [x.shape[0] for x in images]
174
+ all_features, feature_split_size = self.encode_images(torch.cat(images, dim=0), feature_split_size)
175
+ image_features_raw = torch.split(all_features, feature_split_size, dim=0)
176
+
177
+ image_features = []
178
+ for sample_feat in image_features_raw:
179
+ if sample_feat.shape[0] > 1 and self.config.mm_video_pooling_stride > 1: # is video and use different pooling
180
+ b, n, c = sample_feat.shape
181
+ h = w = int(n**0.5)
182
+ sample_feat = nn.functional.avg_pool2d(sample_feat.reshape(b, h, w, c).permute(0, 3, 1, 2),
183
+ kernel_size=self.config.mm_video_pooling_stride,
184
+ stride=self.config.mm_video_pooling_stride).flatten(2,3).transpose(1,2)
185
+
186
+ image_features.append(sample_feat.flatten(0,1).contiguous())
187
+
188
+ else:
189
+ image_features, feature_split_size = self.encode_images(images)
190
+
191
+ # TODO: image start / end is not implemented here to support pretraining.
192
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
193
+ raise NotImplementedError
194
+
195
+ # Let's just add dummy tensors if they do not exist,
196
+ # it is a headache to deal with None all the time.
197
+ # But it is not ideal, and if you have a better idea,
198
+ # please open an issue / submit a PR, thanks.
199
+ _labels = labels
200
+ _position_ids = position_ids
201
+ _attention_mask = attention_mask
202
+ if attention_mask is None:
203
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
204
+ else:
205
+ attention_mask = attention_mask.bool()
206
+ if position_ids is None:
207
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
208
+ if labels is None:
209
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
210
+
211
+ # remove the padding using attention_mask -- FIXME
212
+ _input_ids = input_ids
213
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
214
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
215
+
216
+ new_input_embeds = []
217
+ new_labels = []
218
+ cur_image_idx = 0
219
+ for batch_idx, cur_input_ids in enumerate(input_ids):
220
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
221
+ if num_images == 0:
222
+ cur_image_features = image_features[cur_image_idx]
223
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
224
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
225
+ new_input_embeds.append(cur_input_embeds)
226
+ new_labels.append(labels[batch_idx])
227
+ cur_image_idx += 1
228
+ continue
229
+
230
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
231
+ cur_input_ids_noim = []
232
+ cur_labels = labels[batch_idx]
233
+ cur_labels_noim = []
234
+ for i in range(len(image_token_indices) - 1):
235
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
236
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
237
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
238
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
239
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
240
+ cur_new_input_embeds = []
241
+ cur_new_labels = []
242
+
243
+ for i in range(num_images + 1):
244
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
245
+ cur_new_labels.append(cur_labels_noim[i])
246
+ if i < num_images:
247
+ cur_image_features = image_features[cur_image_idx]
248
+ cur_image_idx += 1
249
+ cur_new_input_embeds.append(cur_image_features)
250
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
251
+
252
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
253
+
254
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
255
+ cur_new_labels = torch.cat(cur_new_labels)
256
+
257
+ new_input_embeds.append(cur_new_input_embeds)
258
+ new_labels.append(cur_new_labels)
259
+
260
+ # Truncate sequences to max length as image embeddings can make the sequence longer
261
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
262
+ if tokenizer_model_max_length is not None:
263
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
264
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
265
+
266
+ # Combine them
267
+ max_len = max(x.shape[0] for x in new_input_embeds)
268
+ batch_size = len(new_input_embeds)
269
+
270
+ new_input_embeds_padded = []
271
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
272
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
273
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
274
+
275
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
276
+ cur_len = cur_new_embed.shape[0]
277
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
278
+ new_input_embeds_padded.append(torch.cat((
279
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
280
+ cur_new_embed
281
+ ), dim=0))
282
+ if cur_len > 0:
283
+ new_labels_padded[i, -cur_len:] = cur_new_labels
284
+ attention_mask[i, -cur_len:] = True
285
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
286
+ else:
287
+ new_input_embeds_padded.append(torch.cat((
288
+ cur_new_embed,
289
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
290
+ ), dim=0))
291
+ if cur_len > 0:
292
+ new_labels_padded[i, :cur_len] = cur_new_labels
293
+ attention_mask[i, :cur_len] = True
294
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
295
+
296
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
297
+
298
+ if _labels is None:
299
+ new_labels = None
300
+ else:
301
+ new_labels = new_labels_padded
302
+
303
+ if _attention_mask is None:
304
+ attention_mask = None
305
+ else:
306
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
307
+
308
+ if _position_ids is None:
309
+ position_ids = None
310
+
311
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
312
+
313
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
314
+ if model_args.mm_use_im_patch_token:
315
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
316
+ self.resize_token_embeddings(len(tokenizer))
317
+
318
+ if model_args.mm_use_im_start_end:
319
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
320
+ self.resize_token_embeddings(len(tokenizer))
321
+
322
+ if num_new_tokens > 0:
323
+ input_embeddings = self.get_input_embeddings().weight.data
324
+ output_embeddings = self.get_output_embeddings().weight.data
325
+
326
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
327
+ dim=0, keepdim=True)
328
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
329
+ dim=0, keepdim=True)
330
+
331
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
332
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
333
+
334
+ if model_args.tune_mm_mlp_adapter:
335
+ for p in self.get_input_embeddings().parameters():
336
+ p.requires_grad = True
337
+ for p in self.get_output_embeddings().parameters():
338
+ p.requires_grad = False
339
+
340
+ if model_args.pretrain_mm_mlp_adapter:
341
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
342
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
343
+ assert num_new_tokens == 2
344
+ if input_embeddings.shape == embed_tokens_weight.shape:
345
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
346
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
347
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
348
+ else:
349
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
350
+ elif model_args.mm_use_im_patch_token:
351
+ if model_args.tune_mm_mlp_adapter:
352
+ for p in self.get_input_embeddings().parameters():
353
+ p.requires_grad = False
354
+ for p in self.get_output_embeddings().parameters():
355
+ p.requires_grad = False
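prepare_inputs_labels_for_multimodal above replaces every IMAGE_TOKEN_INDEX placeholder with the projected image features and masks the inserted span with IGNORE_INDEX in the labels. A stripped-down, single-sample sketch of that splicing step (illustrative shapes; the token ids double as stand-in labels for brevity, not part of this commit):

import torch

IMAGE_TOKEN_INDEX, IGNORE_INDEX = -200, -100
input_ids = torch.tensor([1, 2, IMAGE_TOKEN_INDEX, 3, 4])
text_embeds = torch.randn(5, 8)        # stand-in for embed_tokens over the text positions
image_feats = torch.randn(144, 8)      # stand-in for one image's projected features

pos = (input_ids == IMAGE_TOKEN_INDEX).nonzero().item()
embeds = torch.cat([text_embeds[:pos], image_feats, text_embeds[pos + 1:]])
labels = torch.cat([input_ids[:pos],
                    torch.full((image_feats.shape[0],), IGNORE_INDEX),
                    input_ids[pos + 1:]])
assert embeds.shape[0] == labels.shape[0] == 4 + 144   # text slots + image slots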
llava/model/multimodal_encoder/builder.py ADDED
@@ -0,0 +1,29 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ import os
4
+ from .clip_encoder import CLIPVisionTower, SiglipVisionTower
5
+ from copy import deepcopy
6
+ from .convnext_encoder import ConvNextVisionTower
7
+
8
+ def build_vision_tower(vision_tower_cfg, **kwargs):
9
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
10
+
11
+ if "clip" in vision_tower and vision_tower.startswith("openai"):
12
+ is_absolute_path_exists = os.path.exists(vision_tower)
13
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
14
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
15
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
16
+
17
+ elif "siglip" in vision_tower:
18
+ vision_tower_cfg.input_image_size = 384
19
+ return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
20
+
21
+ elif vision_tower == "convnext-576":
22
+ ## ConvNeXt
23
+ convnext_args = deepcopy(vision_tower_cfg)
24
+ convnext_args.freeze_vision = False
25
+ convnext_args.input_image_size = 576
26
+ convnext_vision_tower = "convnext_xxlarge.clip_laion2b_soup" # hardcode
27
+ return ConvNextVisionTower(convnext_vision_tower, convnext_args)
28
+
29
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
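build_vision_tower dispatches purely on the tower name string, so a config namespace carrying the fields the chosen tower reads is enough to construct it. A hypothetical call sketch (the checkpoint name and field values are placeholders, not taken from this commit):

from types import SimpleNamespace
from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="google/siglip-so400m-patch14-384",   # any name containing "siglip" takes the SigLIP branch
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)   # SiglipVisionTower; the builder forces input_image_size to 384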
llava/model/multimodal_encoder/clip_encoder.py ADDED
@@ -0,0 +1,190 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from transformers import (
6
+ CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig,
7
+ SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
8
+ )
9
+
10
+
11
+ class CLIPVisionTower(nn.Module):
12
+ def __init__(self, vision_tower, args, delay_load=False):
13
+ super().__init__()
14
+
15
+ self.is_loaded = False
16
+
17
+ self.vision_tower_name = vision_tower
18
+ self.select_layer = args.mm_vision_select_layer
19
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
20
+
21
+ if not delay_load:
22
+ self.load_model()
23
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
24
+ self.load_model()
25
+ else:
26
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
27
+
28
+ def load_model(self, device_map=None):
29
+ if self.is_loaded:
30
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
31
+ return
32
+
33
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
34
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
35
+ self.vision_tower.requires_grad_(False)
36
+
37
+ self.is_loaded = True
38
+
39
+ def feature_select(self, image_forward_outs):
40
+ image_features = image_forward_outs.hidden_states[self.select_layer]
41
+ if self.select_feature == 'patch':
42
+ image_features = image_features[:, 1:]
43
+ elif self.select_feature == 'cls_patch':
44
+ image_features = image_features
45
+ else:
46
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
47
+ return image_features
48
+
49
+ @torch.no_grad()
50
+ def forward(self, images):
51
+ if type(images) is list:
52
+ image_features = []
53
+ for image in images:
54
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
55
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
56
+ image_features.append(image_feature)
57
+ else:
58
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
59
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
60
+
61
+ return image_features
62
+
63
+ @property
64
+ def dummy_feature(self):
65
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
66
+
67
+ @property
68
+ def dtype(self):
69
+ return self.vision_tower.dtype
70
+
71
+ @property
72
+ def device(self):
73
+ return self.vision_tower.device
74
+
75
+ @property
76
+ def config(self):
77
+ if self.is_loaded:
78
+ return self.vision_tower.config
79
+ else:
80
+ return self.cfg_only
81
+
82
+ @property
83
+ def hidden_size(self):
84
+ return self.config.hidden_size
85
+
86
+ @property
87
+ def num_patches_per_side(self):
88
+ return self.config.image_size // self.config.patch_size
89
+
90
+ @property
91
+ def num_patches(self):
92
+ return (self.config.image_size // self.config.patch_size) ** 2
93
+
94
+ @property
95
+ def image_size(self):
96
+ return self.config.image_size
97
+
98
+
99
+ class SiglipVisionTower(nn.Module):
100
+
101
+ def __init__(self, vision_tower, args, delay_load=False):
102
+ super().__init__()
103
+
104
+ self.is_loaded = False
105
+ self.vision_tower_name = vision_tower
106
+ self.select_layer = args.mm_vision_select_layer
107
+ self.input_image_size = args.input_image_size
108
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
109
+ self.is_loaded = False
110
+
111
+ if not delay_load:
112
+ self.load_model()
113
+ else:
114
+ self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
115
+
116
+
117
+ def load_model(self, device_map=None):
118
+ if self.is_loaded:
119
+ return
120
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
121
+ self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
122
+ self.image_processor.crop_size = {'height':self.input_image_size, 'width':self.input_image_size}
123
+ self.is_loaded = True
124
+
125
+ def feature_select(self, image_forward_outs, dtype):
126
+ image_features = image_forward_outs.hidden_states
127
+ if self.select_feature == 'patch':
128
+ image_features = image_features[self.select_layer].to(dtype)
129
+ elif self.select_feature == 'list':
130
+ image_features = [feature.to(dtype) for feature in image_features[::7]]
131
+ else:
132
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
133
+ return image_features
134
+
135
+ def forward(self, images):
136
+ if type(images) is list:
137
+ image_features = []
138
+ for image in images:
139
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
140
+ image_feature = self.feature_select(image_forward_out, image.dtype)
141
+ image_features.append(image_feature)
142
+ else:
143
+ batch_size = images.shape[0]
144
+ chunk_size = 256
145
+ image_features = []
146
+
147
+ for i in range(0, batch_size, chunk_size):
148
+ chunk = images[i:i+chunk_size].to(device=self.device, dtype=self.dtype)
149
+ chunk_forward_outs = self.vision_tower(chunk, output_hidden_states=True)
150
+ chunk_features = self.feature_select(chunk_forward_outs, images.dtype)
151
+ image_features.append(chunk_features)
152
+
153
+ image_features = torch.cat(image_features, dim=0)
154
+
155
+ return image_features
156
+
157
+ @property
158
+ def dummy_feature(self):
159
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
160
+
161
+ @property
162
+ def dtype(self):
163
+ return self.vision_tower.dtype
164
+
165
+ @property
166
+ def device(self):
167
+ return self.vision_tower.device
168
+
169
+ @property
170
+ def config(self):
171
+ if self.is_loaded:
172
+ return self.vision_tower.config
173
+ else:
174
+ return self.cfg_only
175
+
176
+ @property
177
+ def hidden_size(self):
178
+ return self.config.hidden_size
179
+
180
+ @property
181
+ def num_patches(self):
182
+ return (self.config.image_size // self.config.patch_size) ** 2
183
+
184
+ @property
185
+ def num_patches_per_side(self):
186
+ return self.config.image_size // self.config.patch_size
187
+
188
+ @property
189
+ def image_size(self):
190
+ return self.config.image_size
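One detail worth noting in SiglipVisionTower.forward above is the chunked batch path: large frame batches (e.g. long videos) are encoded in slices of 256 images and concatenated, rather than in a single pass. The pattern, reduced to a minimal sketch (the encoder callable and sizes are placeholders):

import torch

def chunked_encode(encoder, frames, chunk_size=256):
    # encode `frames` (B, C, H, W) in slices to bound peak activation memory
    feats = [encoder(frames[i:i + chunk_size]) for i in range(0, frames.shape[0], chunk_size)]
    return torch.cat(feats, dim=0)

# e.g. 600 frames -> three encoder calls over 256, 256 and 88 frames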
llava/model/multimodal_encoder/convnext_encoder.py ADDED
@@ -0,0 +1,143 @@
1
+ # This file is modified from https://github.com/luogen1996/LLaVA-HR
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import CLIPImageProcessor
6
+ from .vision_models.convnext import convnext_xxlarge
7
+ from torch.utils.checkpoint import checkpoint
8
+ from llava.utils import load_state_dict_into_model
9
+
10
+ import safetensors
11
+ from collections import OrderedDict
12
+
13
+ cfg={
14
+ "crop_size": 256,
15
+ "do_center_crop": True,
16
+ "do_normalize": True,
17
+ "do_resize": True,
18
+ "feature_extractor_type": "CLIPFeatureExtractor",
19
+ "image_mean": [
20
+ 0.48145466,
21
+ 0.4578275,
22
+ 0.40821073
23
+ ],
24
+ "image_std": [
25
+ 0.26862954,
26
+ 0.26130258,
27
+ 0.27577711
28
+ ],
29
+ "resample": 3,
30
+ "size": 256
31
+ }
32
+
33
+ class ConvNextVisionTower(nn.Module):
34
+ def __init__(self, vision_tower, args, delay_load=False):
35
+ super().__init__()
36
+
37
+ self.is_loaded = False
38
+ self.freeze_vision=args.freeze_vision
39
+ self.input_image_size=args.input_image_size
40
+ self.vision_tower_name = vision_tower
41
+ self.select_layer = -1 # hardcode
42
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
43
+
44
+ self.load_model()
45
+
46
+ def load_model(self):
47
+ self.image_processor = CLIPImageProcessor(**cfg)
48
+ if 'xxlarge' in self.vision_tower_name:
49
+ self.vision_tower = convnext_xxlarge(self.vision_tower_name)
50
+ setattr(self.vision_tower, 'hidden_size', 3072)
51
+
52
+ # load weights manually to avoid deepspeed issue
53
+ # encoder_ckpt = safetensors.torch.load_file("checkpoints/convnext-xxl-clip.safetensors", device="cpu")
54
+ # new_dict = OrderedDict()
55
+ # for k, v in encoder_ckpt.items():
56
+ # if "gamma" in k:
57
+ # k = k.replace("gamma", "weight")
58
+ # new_dict[k] = v
59
+ # encoder_ckpt = new_dict
60
+
61
+ # load_state_dict_into_model(model_to_load=self.vision_tower,
62
+ # state_dict=encoder_ckpt)
63
+
64
+ else:
65
+ raise NotImplementedError
66
+
67
+ if self.freeze_vision:
68
+ self.vision_tower.requires_grad_(False)
69
+
70
+ # Hardcode
71
+ for s in self.vision_tower.stages:
72
+ s.grad_checkpointing = True
73
+
74
+ if self.input_image_size is not None:
75
+ self.image_processor.size=self.input_image_size
76
+ self.image_processor.crop_size={
77
+ 'height':self.input_image_size,
78
+ 'width': self.input_image_size
79
+ }
80
+
81
+ self.is_loaded = True
82
+
83
+ def feature_select(self, image_forward_outs):
84
+ image_features = image_forward_outs[self.select_layer]
85
+ return image_features
86
+
87
+ def forward_features(self, x):
88
+ x = self.vision_tower.stem(x)
89
+ image_forward_out=[]
90
+ for blk in self.vision_tower.stages:
91
+ x = blk(x)
92
+ b,c,h,w=x.shape
93
+ image_forward_out.append(x.view(b,c,-1).transpose(1,2))
94
+ return image_forward_out
95
+
96
+ def forward(self, images):
97
+ if self.freeze_vision:
98
+ with torch.no_grad():
99
+ image_features = self._forward_images(images)
100
+ else:
101
+ image_features = self._forward_images(images)
102
+
103
+ return image_features
104
+
105
+ def _forward_images(self, images):
106
+
107
+ image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
108
+ image_features = self.feature_select(image_forward_outs)
109
+
110
+ return image_features
111
+
112
+ @property
113
+ def dummy_feature(self):
114
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
115
+
116
+ @property
117
+ def dtype(self):
118
+ return next(self.vision_tower.parameters()).dtype
119
+
120
+ @property
121
+ def device(self):
122
+ return next(self.vision_tower.parameters()).device
123
+
124
+ @property
125
+ def config(self):
126
+ raise NotImplementedError
127
+ pass
128
+
129
+ @property
130
+ def num_attention_heads(self):
131
+ # as constant
132
+ return 16
133
+ @property
134
+ def num_layers(self):
135
+ # as constant
136
+ return 4
137
+ @property
138
+ def hidden_size(self):
139
+ return self.vision_tower.hidden_size
140
+
141
+ @property
142
+ def num_patches(self):
143
+ return (cfg['image_size'] // self.patch_embed.patch_size[0]) ** 2
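ConvNextVisionTower.forward_features above returns one flattened token map per ConvNeXt stage, and feature_select with select_layer = -1 keeps only the last (stride-32) stage, whose channel width matches the hidden_size of 3072 set in load_model. A shape sketch for a 576x576 input, assuming the usual convnext_xxlarge stage widths of 384/768/1536/3072 (an assumption about the timm model, not read from this commit):

# Illustrative only; widths and strides are assumptions about convnext_xxlarge.
stage_dims, stage_strides = [384, 768, 1536, 3072], [4, 8, 16, 32]
H = 576
for c, s in zip(stage_dims, stage_strides):
    print(c, (H // s) ** 2)   # last entry: 3072 channels, 18*18 = 324 tokens -> the selected features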
llava/model/multimodal_encoder/vision_models/__init__.py ADDED
File without changes
llava/model/multimodal_encoder/vision_models/convnext.py ADDED
@@ -0,0 +1,1109 @@
1
+ """ ConvNeXt
2
+
3
+ Papers:
4
+ * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
5
+ @Article{liu2022convnet,
6
+ author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
7
+ title = {A ConvNet for the 2020s},
8
+ journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
9
+ year = {2022},
10
+ }
11
+
12
+ * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
13
+ @article{Woo2023ConvNeXtV2,
14
+ title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
15
+ author={Sanghyun Woo and Shoubhik Debnath and Ronghang Hu and Xinlei Chen and Zhuang Liu and In So Kweon and Saining Xie},
16
+ year={2023},
17
+ journal={arXiv preprint arXiv:2301.00808},
18
+ }
19
+
20
+ Original code and weights from:
21
+ * https://github.com/facebookresearch/ConvNeXt, original copyright below
22
+ * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
23
+
24
+ Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
25
+
26
+ Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
27
+ """
28
+ # ConvNeXt
29
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
30
+ # All rights reserved.
31
+ # This source code is licensed under the MIT license
32
+
33
+ # ConvNeXt-V2
34
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
35
+ # All rights reserved.
36
+ # This source code is licensed under the license found in the
37
+ # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
38
+ # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
39
+
40
+ from collections import OrderedDict
41
+ from functools import partial
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+
47
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
48
+ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
49
+ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
50
+ from timm.layers import NormMlpClassifierHead, ClassifierHead
51
+ from timm.models._builder import build_model_with_cfg
52
+ from timm.models._manipulate import named_apply, checkpoint_seq
53
+ from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations
54
+
55
+ __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
56
+
57
+
58
+ class Downsample(nn.Module):
59
+
60
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
61
+ super().__init__()
62
+ avg_stride = stride if dilation == 1 else 1
63
+ if stride > 1 or dilation > 1:
64
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
65
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
66
+ else:
67
+ self.pool = nn.Identity()
68
+
69
+ if in_chs != out_chs:
70
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
71
+ else:
72
+ self.conv = nn.Identity()
73
+
74
+ def forward(self, x):
75
+ x = self.pool(x)
76
+ x = self.conv(x)
77
+ return x
78
+
79
+
80
+ class ConvNeXtBlock(nn.Module):
81
+ """ ConvNeXt Block
82
+ There are two equivalent implementations:
83
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
84
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
85
+
86
+ Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
87
+ choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
88
+ is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ in_chs: int,
94
+ out_chs: Optional[int] = None,
95
+ kernel_size: int = 7,
96
+ stride: int = 1,
97
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
98
+ mlp_ratio: float = 4,
99
+ conv_mlp: bool = False,
100
+ conv_bias: bool = True,
101
+ use_grn: bool = False,
102
+ ls_init_value: Optional[float] = 1e-6,
103
+ act_layer: Union[str, Callable] = 'gelu',
104
+ norm_layer: Optional[Callable] = None,
105
+ drop_path: float = 0.,
106
+ ):
107
+ """
108
+
109
+ Args:
110
+ in_chs: Block input channels.
111
+ out_chs: Block output channels (same as in_chs if None).
112
+ kernel_size: Depthwise convolution kernel size.
113
+ stride: Stride of depthwise convolution.
114
+ dilation: Tuple specifying input and output dilation of block.
115
+ mlp_ratio: MLP expansion ratio.
116
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
117
+ conv_bias: Apply bias for all convolution (linear) layers.
118
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
119
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
120
+ act_layer: Activation layer.
121
+ norm_layer: Normalization layer (defaults to LN if not specified).
122
+ drop_path: Stochastic depth probability.
123
+ """
124
+ super().__init__()
125
+ out_chs = out_chs or in_chs
126
+ dilation = to_ntuple(2)(dilation)
127
+ act_layer = get_act_layer(act_layer)
128
+ if not norm_layer:
129
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
130
+ mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
131
+ self.use_conv_mlp = conv_mlp
132
+ self.conv_dw = create_conv2d(
133
+ in_chs,
134
+ out_chs,
135
+ kernel_size=kernel_size,
136
+ stride=stride,
137
+ dilation=dilation[0],
138
+ depthwise=True,
139
+ bias=conv_bias,
140
+ )
141
+ self.norm = norm_layer(out_chs)
142
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
143
+ self.weight = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
144
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
145
+ self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
146
+ else:
147
+ self.shortcut = nn.Identity()
148
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
149
+
150
+ def forward(self, x):
151
+ shortcut = x
152
+ x = self.conv_dw(x)
153
+ if self.use_conv_mlp:
154
+ x = self.norm(x)
155
+ x = self.mlp(x)
156
+ else:
157
+ x = x.permute(0, 2, 3, 1)
158
+ x = self.norm(x)
159
+ x = self.mlp(x)
160
+ x = x.permute(0, 3, 1, 2)
161
+ if self.weight is not None:
162
+ x = x.mul(self.weight.reshape(1, -1, 1, 1))
163
+
164
+ x = self.drop_path(x) + self.shortcut(shortcut)
165
+ return x
166
+
167
+
168
+ class ConvNeXtStage(nn.Module):
169
+
170
+ def __init__(
171
+ self,
172
+ in_chs,
173
+ out_chs,
174
+ kernel_size=7,
175
+ stride=2,
176
+ depth=2,
177
+ dilation=(1, 1),
178
+ drop_path_rates=None,
179
+ ls_init_value=1.0,
180
+ conv_mlp=False,
181
+ conv_bias=True,
182
+ use_grn=False,
183
+ act_layer='gelu',
184
+ norm_layer=None,
185
+ norm_layer_cl=None
186
+ ):
187
+ super().__init__()
188
+ self.grad_checkpointing = False
189
+
190
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
191
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
192
+ pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
193
+ self.downsample = nn.Sequential(
194
+ norm_layer(in_chs),
195
+ create_conv2d(
196
+ in_chs,
197
+ out_chs,
198
+ kernel_size=ds_ks,
199
+ stride=stride,
200
+ dilation=dilation[0],
201
+ padding=pad,
202
+ bias=conv_bias,
203
+ ),
204
+ )
205
+ in_chs = out_chs
206
+ else:
207
+ self.downsample = nn.Identity()
208
+
209
+ drop_path_rates = drop_path_rates or [0.] * depth
210
+ stage_blocks = []
211
+ for i in range(depth):
212
+ stage_blocks.append(ConvNeXtBlock(
213
+ in_chs=in_chs,
214
+ out_chs=out_chs,
215
+ kernel_size=kernel_size,
216
+ dilation=dilation[1],
217
+ drop_path=drop_path_rates[i],
218
+ ls_init_value=ls_init_value,
219
+ conv_mlp=conv_mlp,
220
+ conv_bias=conv_bias,
221
+ use_grn=use_grn,
222
+ act_layer=act_layer,
223
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
224
+ ))
225
+ in_chs = out_chs
226
+ self.blocks = nn.Sequential(*stage_blocks)
227
+
228
+ def forward(self, x):
229
+ x = self.downsample(x)
230
+ if self.grad_checkpointing and self.training and not torch.jit.is_scripting():
231
+ x = checkpoint_seq(self.blocks, x)
232
+ else:
233
+ x = self.blocks(x)
234
+ return x
235
+
236
+
237
+ class ConvNeXt(nn.Module):
238
+ r""" ConvNeXt
239
+ A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_chans: int = 3,
245
+ num_classes: int = 1000,
246
+ global_pool: str = 'avg',
247
+ output_stride: int = 32,
248
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
249
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
250
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
251
+ ls_init_value: Optional[float] = 1e-6,
252
+ stem_type: str = 'patch',
253
+ patch_size: int = 4,
254
+ head_init_scale: float = 1.,
255
+ head_norm_first: bool = False,
256
+ head_hidden_size: Optional[int] = None,
257
+ conv_mlp: bool = False,
258
+ conv_bias: bool = True,
259
+ use_grn: bool = False,
260
+ act_layer: Union[str, Callable] = 'gelu',
261
+ norm_layer: Optional[Union[str, Callable]] = None,
262
+ norm_eps: Optional[float] = None,
263
+ drop_rate: float = 0.,
264
+ drop_path_rate: float = 0.,
265
+ ):
266
+ """
267
+ Args:
268
+ in_chans: Number of input image channels.
269
+ num_classes: Number of classes for classification head.
270
+ global_pool: Global pooling type.
271
+ output_stride: Output stride of network, one of (8, 16, 32).
272
+ depths: Number of blocks at each stage.
273
+ dims: Feature dimension at each stage.
274
+ kernel_sizes: Depthwise convolution kernel-sizes for each stage.
275
+ ls_init_value: Init value for Layer Scale, disabled if None.
276
+ stem_type: Type of stem.
277
+ patch_size: Stem patch size for patch stem.
278
+ head_init_scale: Init scaling value for classifier weights and biases.
279
+ head_norm_first: Apply normalization before global pool + head.
280
+ head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
281
+ conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
282
+ conv_bias: Use bias layers w/ all convolutions.
283
+ use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
284
+ act_layer: Activation layer type.
285
+ norm_layer: Normalization layer type.
286
+ drop_rate: Head pre-classifier dropout rate.
287
+ drop_path_rate: Stochastic depth drop rate.
288
+ """
289
+ super().__init__()
290
+ assert output_stride in (8, 16, 32)
291
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
292
+ if norm_layer is None:
293
+ norm_layer = LayerNorm2d
294
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
295
+ if norm_eps is not None:
296
+ norm_layer = partial(norm_layer, eps=norm_eps)
297
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
298
+ else:
299
+ assert conv_mlp,\
300
+ 'If a norm_layer is specified, conv MLP must be used so all norm layers expect rank-4, channels-first input'
301
+ norm_layer_cl = norm_layer
302
+ if norm_eps is not None:
303
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
304
+
305
+ self.num_classes = num_classes
306
+ self.drop_rate = drop_rate
307
+ self.feature_info = []
308
+
309
+ assert stem_type in ('patch', 'overlap', 'overlap_tiered')
310
+ if stem_type == 'patch':
311
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
+ self.stem = nn.Sequential(
313
+ nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
314
+ norm_layer(dims[0]),
315
+ )
316
+ stem_stride = patch_size
317
+ else:
318
+ mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
319
+ self.stem = nn.Sequential(
320
+ nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
321
+ nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
322
+ norm_layer(dims[0]),
323
+ )
324
+ stem_stride = 4
325
+
326
+ self.stages = nn.Sequential()
327
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
328
+ stages = []
329
+ prev_chs = dims[0]
330
+ curr_stride = stem_stride
331
+ dilation = 1
332
+ # 4 feature resolution stages, each consisting of multiple residual blocks
333
+ for i in range(4):
334
+ stride = 2 if curr_stride == 2 or i > 0 else 1
335
+ if curr_stride >= output_stride and stride > 1:
336
+ dilation *= stride
337
+ stride = 1
338
+ curr_stride *= stride
339
+ first_dilation = 1 if dilation in (1, 2) else 2
340
+ out_chs = dims[i]
341
+ stages.append(ConvNeXtStage(
342
+ prev_chs,
343
+ out_chs,
344
+ kernel_size=kernel_sizes[i],
345
+ stride=stride,
346
+ dilation=(first_dilation, dilation),
347
+ depth=depths[i],
348
+ drop_path_rates=dp_rates[i],
349
+ ls_init_value=ls_init_value,
350
+ conv_mlp=conv_mlp,
351
+ conv_bias=conv_bias,
352
+ use_grn=use_grn,
353
+ act_layer=act_layer,
354
+ norm_layer=norm_layer,
355
+ norm_layer_cl=norm_layer_cl,
356
+ ))
357
+ prev_chs = out_chs
358
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
359
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
360
+ self.stages = nn.Sequential(*stages)
361
+ self.num_features = prev_chs
362
+
363
+ # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
364
+ # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
365
+ if head_norm_first:
366
+ assert not head_hidden_size
367
+ self.norm_pre = norm_layer(self.num_features)
368
+ self.head = ClassifierHead(
369
+ self.num_features,
370
+ num_classes,
371
+ pool_type=global_pool,
372
+ drop_rate=self.drop_rate,
373
+ )
374
+ else:
375
+ self.norm_pre = nn.Identity()
376
+ self.head = NormMlpClassifierHead(
377
+ self.num_features,
378
+ num_classes,
379
+ hidden_size=head_hidden_size,
380
+ pool_type=global_pool,
381
+ drop_rate=self.drop_rate,
382
+ norm_layer=norm_layer,
383
+ act_layer='gelu',
384
+ )
385
+ # named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
386
+
387
+ @torch.jit.ignore
388
+ def group_matcher(self, coarse=False):
389
+ return dict(
390
+ stem=r'^stem',
391
+ blocks=r'^stages\.(\d+)' if coarse else [
392
+ (r'^stages\.(\d+)\.downsample', (0,)), # blocks
393
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
394
+ (r'^norm_pre', (99999,))
395
+ ]
396
+ )
397
+
398
+ @torch.jit.ignore
399
+ def set_grad_checkpointing(self, enable=True):
400
+ for s in self.stages:
401
+ s.grad_checkpointing = enable
402
+
403
+ @torch.jit.ignore
404
+ def get_classifier(self):
405
+ return self.head.fc
406
+
407
+ def reset_classifier(self, num_classes=0, global_pool=None):
408
+ self.head.reset(num_classes, global_pool)
409
+
410
+ def forward_features(self, x):
411
+ x = self.stem(x)
412
+ x = self.stages(x)
413
+ x = self.norm_pre(x)
414
+ return x
415
+
416
+ def forward_head(self, x, pre_logits: bool = False):
417
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
418
+
419
+ def forward(self, x):
420
+ x = self.forward_features(x)
421
+ x = self.forward_head(x)
422
+ return x
423
+
424
+
425
+ def _init_weights(module, name=None, head_init_scale=1.0):
426
+ if isinstance(module, nn.Conv2d):
427
+ trunc_normal_(module.weight, std=.02)
428
+ if module.bias is not None:
429
+ nn.init.zeros_(module.bias)
430
+ elif isinstance(module, nn.Linear):
431
+ trunc_normal_(module.weight, std=.02)
432
+ nn.init.zeros_(module.bias)
433
+ if name and 'head.' in name:
434
+ module.weight.data.mul_(head_init_scale)
435
+ module.bias.data.mul_(head_init_scale)
436
+
437
+
438
+ def checkpoint_filter_fn(state_dict, model):
439
+ """ Remap FB checkpoints -> timm """
440
+ if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
442
+ out_dict = {k.replace('gamma', 'weight'): v for k, v in state_dict.items()}
443
+ return out_dict # non-FB checkpoint
444
+ if 'model' in state_dict:
445
+ state_dict = state_dict['model']
446
+
447
+ out_dict = {}
448
+ if 'visual.trunk.stem.0.weight' in state_dict:
449
+ out_dict = {k.replace('visual.trunk.', '').replace('gamma', 'weight'): v for k, v in state_dict.items() if
450
+ k.startswith('visual.trunk.')}
451
+
452
+ if 'visual.head.proj.weight' in state_dict:
453
+ out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
454
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
455
+ elif 'visual.head.mlp.fc1.weight' in state_dict:
456
+ out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
457
+ out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
458
+ out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
459
+ out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
460
+ return out_dict
461
+
462
+ import re
463
+ for k, v in state_dict.items():
464
+ k = k.replace('downsample_layers.0.', 'stem.')
465
+ k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
466
+ k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
467
+ k = k.replace('dwconv', 'conv_dw')
468
+ k = k.replace('pwconv', 'mlp.fc')
469
+ if 'grn' in k:
470
+ k = k.replace('grn.beta', 'mlp.grn.bias')
471
+ k = k.replace('grn.gamma', 'mlp.grn.weight')
472
+ v = v.reshape(v.shape[-1])
473
+ k = k.replace('head.', 'head.fc.')
474
+ if k.startswith('norm.'):
475
+ k = k.replace('norm', 'head.norm')
476
+ if v.ndim == 2 and 'head' not in k:
477
+ model_shape = model.state_dict()[k].shape
478
+ v = v.reshape(model_shape)
479
+ k = k.replace('gamma', 'weight')
480
+ out_dict[k] = v
481
+
482
+ return out_dict
483
+
484
+
485
+ def _create_convnext(variant, pretrained=False, **kwargs):
486
+ if kwargs.get('pretrained_cfg', '') == 'fcmae':
487
+ # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
488
+ # This is workaround loading with num_classes=0 w/o removing norm-layer.
489
+ kwargs.setdefault('pretrained_strict', False)
490
+
491
+ model = build_model_with_cfg(
492
+ ConvNeXt, variant, pretrained,
493
+ pretrained_filter_fn=checkpoint_filter_fn,
494
+ feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
495
+ **kwargs)
496
+ return model
497
+
498
+
499
+ def _cfg(url='', **kwargs):
500
+ return {
501
+ 'url': url,
502
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
503
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
504
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
505
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
506
+ **kwargs
507
+ }
508
+
509
+
510
+ def _cfgv2(url='', **kwargs):
511
+ return {
512
+ 'url': url,
513
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
514
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
515
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
516
+ 'first_conv': 'stem.0', 'classifier': 'head.fc',
517
+ 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
518
+ 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
519
+ 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
520
+ **kwargs
521
+ }
522
+
523
+
524
+ default_cfgs = generate_default_cfgs({
525
+ # timm specific variants
526
+ 'convnext_tiny.in12k_ft_in1k': _cfg(
527
+ hf_hub_id='timm/',
528
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
529
+ 'convnext_small.in12k_ft_in1k': _cfg(
530
+ hf_hub_id='timm/',
531
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
532
+
533
+ 'convnext_atto.d2_in1k': _cfg(
534
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
535
+ hf_hub_id='timm/',
536
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
537
+ 'convnext_atto_ols.a2_in1k': _cfg(
538
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
539
+ hf_hub_id='timm/',
540
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
541
+ 'convnext_femto.d1_in1k': _cfg(
542
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
543
+ hf_hub_id='timm/',
544
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
545
+ 'convnext_femto_ols.d1_in1k': _cfg(
546
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
547
+ hf_hub_id='timm/',
548
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
549
+ 'convnext_pico.d1_in1k': _cfg(
550
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
551
+ hf_hub_id='timm/',
552
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
553
+ 'convnext_pico_ols.d1_in1k': _cfg(
554
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
555
+ hf_hub_id='timm/',
556
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
557
+ 'convnext_nano.in12k_ft_in1k': _cfg(
558
+ hf_hub_id='timm/',
559
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
560
+ 'convnext_nano.d1h_in1k': _cfg(
561
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
562
+ hf_hub_id='timm/',
563
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
564
+ 'convnext_nano_ols.d1h_in1k': _cfg(
565
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
566
+ hf_hub_id='timm/',
567
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
568
+ 'convnext_tiny_hnf.a2h_in1k': _cfg(
569
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
570
+ hf_hub_id='timm/',
571
+ crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
572
+
573
+ 'convnext_tiny.in12k_ft_in1k_384': _cfg(
574
+ hf_hub_id='timm/',
575
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
576
+ 'convnext_small.in12k_ft_in1k_384': _cfg(
577
+ hf_hub_id='timm/',
578
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
579
+
580
+ 'convnext_nano.in12k': _cfg(
581
+ hf_hub_id='timm/',
582
+ crop_pct=0.95, num_classes=11821),
583
+ 'convnext_tiny.in12k': _cfg(
584
+ hf_hub_id='timm/',
585
+ crop_pct=0.95, num_classes=11821),
586
+ 'convnext_small.in12k': _cfg(
587
+ hf_hub_id='timm/',
588
+ crop_pct=0.95, num_classes=11821),
589
+
590
+ 'convnext_tiny.fb_in22k_ft_in1k': _cfg(
591
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
592
+ hf_hub_id='timm/',
593
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
594
+ 'convnext_small.fb_in22k_ft_in1k': _cfg(
595
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
596
+ hf_hub_id='timm/',
597
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
598
+ 'convnext_base.fb_in22k_ft_in1k': _cfg(
599
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
600
+ hf_hub_id='timm/',
601
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
602
+ 'convnext_large.fb_in22k_ft_in1k': _cfg(
603
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
604
+ hf_hub_id='timm/',
605
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
606
+ 'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
607
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
608
+ hf_hub_id='timm/',
609
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
610
+
611
+ 'convnext_tiny.fb_in1k': _cfg(
612
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
613
+ hf_hub_id='timm/',
614
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
615
+ 'convnext_small.fb_in1k': _cfg(
616
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
617
+ hf_hub_id='timm/',
618
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
619
+ 'convnext_base.fb_in1k': _cfg(
620
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
621
+ hf_hub_id='timm/',
622
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
623
+ 'convnext_large.fb_in1k': _cfg(
624
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
625
+ hf_hub_id='timm/',
626
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
627
+
628
+ 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
629
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
630
+ hf_hub_id='timm/',
631
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
632
+ 'convnext_small.fb_in22k_ft_in1k_384': _cfg(
633
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
634
+ hf_hub_id='timm/',
635
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
636
+ 'convnext_base.fb_in22k_ft_in1k_384': _cfg(
637
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
638
+ hf_hub_id='timm/',
639
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
640
+ 'convnext_large.fb_in22k_ft_in1k_384': _cfg(
641
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
642
+ hf_hub_id='timm/',
643
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
644
+ 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
645
+ url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
646
+ hf_hub_id='timm/',
647
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
648
+
649
+ 'convnext_tiny.fb_in22k': _cfg(
650
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
651
+ hf_hub_id='timm/',
652
+ num_classes=21841),
653
+ 'convnext_small.fb_in22k': _cfg(
654
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
655
+ hf_hub_id='timm/',
656
+ num_classes=21841),
657
+ 'convnext_base.fb_in22k': _cfg(
658
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
659
+ hf_hub_id='timm/',
660
+ num_classes=21841),
661
+ 'convnext_large.fb_in22k': _cfg(
662
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
663
+ hf_hub_id='timm/',
664
+ num_classes=21841),
665
+ 'convnext_xlarge.fb_in22k': _cfg(
666
+ url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
667
+ hf_hub_id='timm/',
668
+ num_classes=21841),
669
+
670
+ 'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
671
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
672
+ hf_hub_id='timm/',
673
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
674
+ 'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
675
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
676
+ hf_hub_id='timm/',
677
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
678
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
679
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
680
+ hf_hub_id='timm/',
681
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
682
+ 'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
683
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
684
+ hf_hub_id='timm/',
685
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
686
+ 'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
687
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
688
+ hf_hub_id='timm/',
689
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
690
+ 'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
691
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
692
+ hf_hub_id='timm/',
693
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
694
+ 'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
695
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
696
+ hf_hub_id='timm/',
697
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
698
+ 'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
699
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
700
+ hf_hub_id='timm/',
701
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
702
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
703
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
704
+ hf_hub_id='timm/',
705
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
706
+ 'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
707
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
708
+ hf_hub_id='timm/',
709
+ input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
710
+
711
+ 'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
712
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
713
+ hf_hub_id='timm/',
714
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
715
+ 'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
716
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
717
+ hf_hub_id='timm/',
718
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
719
+ 'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
720
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
721
+ hf_hub_id='timm/',
722
+ test_input_size=(3, 288, 288), test_crop_pct=0.95),
723
+ 'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
724
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
725
+ hf_hub_id='timm/',
726
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
727
+ 'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
728
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
729
+ hf_hub_id='timm/',
730
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
731
+ 'convnextv2_base.fcmae_ft_in1k': _cfgv2(
732
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
733
+ hf_hub_id='timm/',
734
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
735
+ 'convnextv2_large.fcmae_ft_in1k': _cfgv2(
736
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
737
+ hf_hub_id='timm/',
738
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
739
+ 'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
740
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
741
+ hf_hub_id='timm/',
742
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
743
+
744
+ 'convnextv2_atto.fcmae': _cfgv2(
745
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
746
+ hf_hub_id='timm/',
747
+ num_classes=0),
748
+ 'convnextv2_femto.fcmae': _cfgv2(
749
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
750
+ hf_hub_id='timm/',
751
+ num_classes=0),
752
+ 'convnextv2_pico.fcmae': _cfgv2(
753
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
754
+ hf_hub_id='timm/',
755
+ num_classes=0),
756
+ 'convnextv2_nano.fcmae': _cfgv2(
757
+ url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
758
+ hf_hub_id='timm/',
759
+ num_classes=0),
760
+ 'convnextv2_tiny.fcmae': _cfgv2(
761
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
762
+ hf_hub_id='timm/',
763
+ num_classes=0),
764
+ 'convnextv2_base.fcmae': _cfgv2(
765
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
766
+ hf_hub_id='timm/',
767
+ num_classes=0),
768
+ 'convnextv2_large.fcmae': _cfgv2(
769
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
770
+ hf_hub_id='timm/',
771
+ num_classes=0),
772
+ 'convnextv2_huge.fcmae': _cfgv2(
773
+ url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
774
+ hf_hub_id='timm/',
775
+ num_classes=0),
776
+
777
+ 'convnextv2_small.untrained': _cfg(),
778
+
779
+ # CLIP weights, fine-tuned on in1k or in12k + in1k
780
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
781
+ hf_hub_id='timm/',
782
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
783
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
784
+ 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
785
+ hf_hub_id='timm/',
786
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
787
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
788
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
789
+ hf_hub_id='timm/',
790
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
791
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
792
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
793
+ hf_hub_id='timm/',
794
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
795
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
796
+
797
+ 'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
798
+ hf_hub_id='timm/',
799
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
800
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
801
+ 'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
802
+ hf_hub_id='timm/',
803
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
804
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
805
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
806
+ hf_hub_id='timm/',
807
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
808
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
809
+ ),
810
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
811
+ hf_hub_id='timm/',
812
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
813
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
814
+ ),
815
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
816
+ hf_hub_id='timm/',
817
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
818
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
819
+
820
+ 'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
821
+ hf_hub_id='timm/',
822
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
823
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
824
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
825
+ hf_hub_id='timm/',
826
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
827
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
828
+ 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
829
+ hf_hub_id='timm/',
830
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
831
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
832
+ 'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
833
+ hf_hub_id='timm/',
834
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
835
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
836
+ 'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
837
+ hf_hub_id='timm/',
838
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
839
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
840
+
841
+ # CLIP original image tower weights
842
+ 'convnext_base.clip_laion2b': _cfg(
843
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
844
+ hf_hub_filename='open_clip_pytorch_model.bin',
845
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
846
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
847
+ 'convnext_base.clip_laion2b_augreg': _cfg(
848
+ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
849
+ hf_hub_filename='open_clip_pytorch_model.bin',
850
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
851
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
852
+ 'convnext_base.clip_laiona': _cfg(
853
+ hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
854
+ hf_hub_filename='open_clip_pytorch_model.bin',
855
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
856
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
857
+ 'convnext_base.clip_laiona_320': _cfg(
858
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
859
+ hf_hub_filename='open_clip_pytorch_model.bin',
860
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
861
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
862
+ 'convnext_base.clip_laiona_augreg_320': _cfg(
863
+ hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
864
+ hf_hub_filename='open_clip_pytorch_model.bin',
865
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
866
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
867
+ 'convnext_large_mlp.clip_laion2b_augreg': _cfg(
868
+ hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
869
+ hf_hub_filename='open_clip_pytorch_model.bin',
870
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
871
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
872
+ 'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
873
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
874
+ hf_hub_filename='open_clip_pytorch_model.bin',
875
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
876
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
877
+ 'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
878
+ hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
879
+ hf_hub_filename='open_clip_pytorch_model.bin',
880
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
881
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
882
+ 'convnext_xxlarge.clip_laion2b_soup': _cfg(
883
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
884
+ hf_hub_filename='open_clip_pytorch_model.bin',
885
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
886
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
887
+ 'convnext_xxlarge.clip_laion2b_rewind': _cfg(
888
+ hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
889
+ hf_hub_filename='open_clip_pytorch_model.bin',
890
+ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
891
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
892
+ })
893
+
894
+
895
+ # @register_model
896
+ # def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
897
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
898
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
899
+ # model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
900
+ # return model
901
+
902
+
903
+ # @register_model
904
+ # def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
905
+ # # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
906
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
907
+ # model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
908
+ # return model
909
+
910
+
911
+ # @register_model
912
+ # def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
913
+ # # timm femto variant
914
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
915
+ # model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
916
+ # return model
917
+
918
+
919
+ # @register_model
920
+ # def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
921
+ # # timm femto variant
922
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
923
+ # model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
924
+ # return model
925
+
926
+
927
+ # @register_model
928
+ # def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
929
+ # # timm pico variant
930
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
931
+ # model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
932
+ # return model
933
+
934
+
935
+ # @register_model
936
+ # def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
937
+ # # timm nano variant with overlapping 3x3 conv stem
938
+ # model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
939
+ # model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
940
+ # return model
941
+
942
+
943
+ # @register_model
944
+ # def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
945
+ # # timm nano variant with standard stem and head
946
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
947
+ # model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
948
+ # return model
949
+
950
+
951
+ # @register_model
952
+ # def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
953
+ # # experimental nano variant with overlapping conv stem
954
+ # model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
955
+ # model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
956
+ # return model
957
+
958
+
959
+ # @register_model
960
+ # def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
961
+ # # experimental tiny variant with norm before pooling in head (head norm first)
962
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
963
+ # model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
964
+ # return model
965
+
966
+
967
+ # @register_model
968
+ # def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
969
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
970
+ # model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
971
+ # return model
972
+
973
+
974
+ # @register_model
975
+ # def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
976
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
977
+ # model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
978
+ # return model
979
+
980
+ # @register_model
981
+ # def convnext_base_clip(pretrained='', **kwargs) -> ConvNeXt:
982
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
983
+ # model = _create_convnext(pretrained, pretrained=True, **dict(model_args, **kwargs))
984
+ # return model
985
+
986
+ # @register_model
987
+ # def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
988
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
989
+ # model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
990
+ # return model
991
+
992
+
993
+ # @register_model
994
+ # def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
995
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
996
+ # model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
997
+ # return model
998
+
999
+
1000
+ # @register_model
1001
+ # def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
1002
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
1003
+ # model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
1004
+ # return model
1005
+
1006
+
1007
+ # @register_model
1008
+ # def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
1009
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
1010
+ # model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1011
+ # return model
1012
+
1013
+
1014
+ # @register_model
1015
+ def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
1016
+ model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
1017
+ # model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
1018
+ model = _create_convnext('convnext_xxlarge', pretrained=False, **dict(model_args, **kwargs))
1019
+ return model
1020
+
1021
+
1022
+ # @register_model
1023
+ # def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
1024
+ # # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
1025
+ # model_args = dict(
1026
+ # depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
1027
+ # model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
1028
+ # return model
1029
+
1030
+
1031
+ # @register_model
1032
+ # def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
1033
+ # # timm femto variant
1034
+ # model_args = dict(
1035
+ # depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
1036
+ # model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
1037
+ # return model
1038
+
1039
+
1040
+ # @register_model
1041
+ # def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
1042
+ # # timm pico variant
1043
+ # model_args = dict(
1044
+ # depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
1045
+ # model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
1046
+ # return model
1047
+
1048
+
1049
+ # @register_model
1050
+ # def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
1051
+ # # timm nano variant with standard stem and head
1052
+ # model_args = dict(
1053
+ # depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
1054
+ # model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
1055
+ # return model
1056
+
1057
+
1058
+ # @register_model
1059
+ # def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
1060
+ # model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
1061
+ # model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
1062
+ # return model
1063
+
1064
+
1065
+ # @register_model
1066
+ # def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
1067
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
1068
+ # model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
1069
+ # return model
1070
+
1071
+
1072
+ # @register_model
1073
+ # def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
1074
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
1075
+ # model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
1076
+ # return model
1077
+
1078
+
1079
+ # @register_model
1080
+ # def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
1081
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
1082
+ # model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
1083
+ # return model
1084
+
1085
+
1086
+ # @register_model
1087
+ # def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
1088
+ # model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
1089
+ # model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
1090
+ # return model
1091
+
1092
+
1093
+ # register_model_deprecations(__name__, {
1094
+ # 'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
1095
+ # 'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
1096
+ # 'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
1097
+ # 'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
1098
+ # 'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
1099
+ # 'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
1100
+ # 'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
1101
+ # 'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
1102
+ # 'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
1103
+ # 'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
1104
+ # 'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
1105
+ # 'convnext_small_in22k': 'convnext_small.fb_in22k',
1106
+ # 'convnext_base_in22k': 'convnext_base.fb_in22k',
1107
+ # 'convnext_large_in22k': 'convnext_large.fb_in22k',
1108
+ # 'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
1109
+ # })
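A small sanity check of checkpoint_filter_fn above (illustrative only, not part of this commit): for a non-FB / OpenCLIP-style checkpoint that already contains norm_pre.weight, the function simply renames legacy 'gamma' layer-scale parameters to 'weight', the same remapping the commented-out manual loader in the vision tower performs.

# Illustrative only. The early-return branch of checkpoint_filter_fn never touches `model`,
# so None is passed to avoid instantiating the very large network just for this check.
import torch

fake_ckpt = {
    'norm_pre.weight': torch.ones(3072),
    'stages.0.blocks.0.gamma': torch.ones(384),   # legacy layer-scale parameter name
}
remapped = checkpoint_filter_fn(fake_ckpt, model=None)
print(sorted(remapped))   # ['norm_pre.weight', 'stages.0.blocks.0.weight']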
llava/model/multimodal_projector/builder.py ADDED
@@ -0,0 +1,53 @@
1
+ # Based on https://github.com/haotian-liu/LLaVA.
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import re
6
+
7
+ class IdentityMap(nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ def forward(self, x, *args, **kwargs):
12
+ return x
13
+
14
+ @property
15
+ def config(self):
16
+ return {"mm_projector_type": 'identity'}
17
+
18
+
19
+ class SimpleResBlock(nn.Module):
20
+ def __init__(self, channels):
21
+ super().__init__()
22
+ self.pre_norm = nn.LayerNorm(channels)
23
+
24
+ self.proj = nn.Sequential(
25
+ nn.Linear(channels, channels),
26
+ nn.GELU(),
27
+ nn.Linear(channels, channels)
28
+ )
29
+ def forward(self, x):
30
+ x = self.pre_norm(x)
31
+ return x + self.proj(x)
32
+
33
+
34
+ def build_vision_projector(config, delay_load=False, fpn_input_dim=[], **kwargs):
35
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
36
+ # if getattr(config, 'mm_use_4_vision_tokens', False):
37
+ # mm_hidden_size = config.mm_hidden_size * 4
38
+ # else:
39
+ mm_hidden_size = config.mm_hidden_size
40
+ if projector_type == 'linear':
41
+ return nn.Linear(mm_hidden_size, config.hidden_size)
42
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
43
+ if mlp_gelu_match:
44
+ mlp_depth = int(mlp_gelu_match.group(1))
45
+ modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
46
+ for _ in range(1, mlp_depth):
47
+ modules.append(nn.GELU())
48
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
49
+ return nn.Sequential(*modules)
50
+ if projector_type == 'identity':
51
+ return IdentityMap()
52
+
53
+ raise ValueError(f'Unknown projector type: {projector_type}')
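For reference, a minimal sketch of how build_vision_projector resolves the mlpNx_gelu pattern (illustrative only, not part of this commit; SimpleNamespace stands in for the real config, and the 4096 LLM hidden size is an assumption):

# Illustrative only -- attribute names follow what build_vision_projector reads.
from types import SimpleNamespace

cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=3072, hidden_size=4096)
projector = build_vision_projector(cfg)
print(projector)   # Sequential(Linear(3072->4096), GELU(), Linear(4096->4096))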
llava/train/llava_trainer.py ADDED
@@ -0,0 +1,321 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ import os
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from torch.utils.data import Sampler
8
+
9
+ from transformers import Trainer
10
+ from transformers.trainer import (
11
+ is_sagemaker_mp_enabled,
12
+ get_parameter_names,
13
+ has_length,
14
+ ALL_LAYERNORM_LAYERS,
15
+ logger,
16
+ )
17
+ from typing import List, Optional
18
+
19
+
20
+ def maybe_zero_3(param, ignore_status=False, name=None):
21
+ from deepspeed import zero
22
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
23
+ if hasattr(param, "ds_id"):
24
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
25
+ if not ignore_status:
26
+ print(name, 'no ignore status')
27
+ with zero.GatheredParameters([param]):
28
+ param = param.data.detach().cpu().clone()
29
+ else:
30
+ param = param.detach().cpu().clone()
31
+ return param
32
+
33
+
34
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
35
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
36
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
37
+ return to_return
38
+
39
+
40
+ def split_to_even_chunks(indices, lengths, num_chunks):
41
+ """
42
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
43
+ """
44
+
45
+ if len(indices) % num_chunks != 0:
46
+ return [indices[i::num_chunks] for i in range(num_chunks)]
47
+
48
+ num_indices_per_chunk = len(indices) // num_chunks
49
+
50
+ chunks = [[] for _ in range(num_chunks)]
51
+ chunks_lengths = [0 for _ in range(num_chunks)]
52
+ for index in indices:
53
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
54
+ chunks[shortest_chunk].append(index)
55
+ chunks_lengths[shortest_chunk] += lengths[index]
56
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
57
+ chunks_lengths[shortest_chunk] = float("inf")
58
+
59
+ return chunks
60
+
61
+
62
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
63
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
64
+ assert all(l != 0 for l in lengths), "Should not have zero length."
65
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
66
+ # all samples are in the same modality
67
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
68
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
69
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
70
+
71
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
72
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
73
+ megabatch_size = world_size * batch_size
74
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
75
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
76
+
77
+ last_mm = mm_megabatches[-1]
78
+ last_lang = lang_megabatches[-1]
79
+ additional_batch = last_mm + last_lang
80
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
81
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
82
+ megabatches = [megabatches[i] for i in megabatch_indices]
83
+
84
+ if len(additional_batch) > 0:
85
+ megabatches.append(sorted(additional_batch))
86
+
87
+ return [i for megabatch in megabatches for i in megabatch]
88
+
89
+
90
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
91
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
92
+ indices = torch.randperm(len(lengths), generator=generator)
93
+ megabatch_size = world_size * batch_size
94
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
95
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
96
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
97
+
98
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
99
+
100
+
101
+ class LengthGroupedSampler(Sampler):
102
+ r"""
103
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
104
+ keeping a bit of randomness.
105
+ """
106
+
107
+ def __init__(
108
+ self,
109
+ batch_size: int,
110
+ world_size: int,
111
+ lengths: Optional[List[int]] = None,
112
+ generator=None,
113
+ group_by_modality: bool = False,
114
+ ):
115
+ if lengths is None:
116
+ raise ValueError("Lengths must be provided.")
117
+
118
+ self.batch_size = batch_size
119
+ self.world_size = world_size
120
+ self.lengths = lengths
121
+ self.generator = generator
122
+ self.group_by_modality = group_by_modality
123
+
124
+ def __len__(self):
125
+ return len(self.lengths)
126
+
127
+ def __iter__(self):
128
+ if self.group_by_modality:
129
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
130
+ else:
131
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
132
+ return iter(indices)
133
+
134
+
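A toy illustration of the length-grouping idea behind LengthGroupedSampler: shuffle, cut into world_size * batch_size "megabatches", then sort each megabatch by length so per-rank chunks see samples of similar size. The lengths below are made up for the sketch.

# Toy sketch of the megabatch construction used by get_length_grouped_indices.
import torch

lengths = [5, 50, 7, 48, 6, 52, 9, 47]        # toy sequence lengths
batch_size, world_size = 2, 2
megabatch_size = batch_size * world_size

g = torch.Generator().manual_seed(0)
indices = torch.randperm(len(lengths), generator=g).tolist()
megabatches = [indices[i:i + megabatch_size] for i in range(0, len(indices), megabatch_size)]
megabatches = [sorted(m, key=lambda i: lengths[i], reverse=True) for m in megabatches]
print(megabatches)  # each megabatch is sorted longest-first, which keeps padding low per rank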
135
+ class LlavaTrainer(Trainer):
136
+
137
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
138
+ if self.train_dataset is None or not has_length(self.train_dataset):
139
+ return None
140
+
141
+ if self.args.group_by_modality_length:
142
+ lengths = self.train_dataset.modality_lengths
143
+ return LengthGroupedSampler(
144
+ self.args.train_batch_size,
145
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
146
+ lengths=lengths,
147
+ group_by_modality=True,
148
+ )
149
+ else:
150
+ return super()._get_train_sampler()
151
+
152
+ def create_optimizer(self):
153
+ """
154
+ Setup the optimizer.
155
+
156
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
157
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
158
+ """
159
+ if is_sagemaker_mp_enabled():
160
+ return super().create_optimizer()
161
+
162
+ opt_model = self.model
163
+
164
+ if self.optimizer is None:
165
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
166
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
167
+ if self.args.mm_projector_lr is not None:
168
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
169
+ optimizer_grouped_parameters = [
170
+ {
171
+ "params": [
172
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
173
+ ],
174
+ "weight_decay": self.args.weight_decay,
175
+ },
176
+ {
177
+ "params": [
178
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
179
+ ],
180
+ "weight_decay": 0.0,
181
+ },
182
+ {
183
+ "params": [
184
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
185
+ ],
186
+ "weight_decay": self.args.weight_decay,
187
+ "lr": self.args.mm_projector_lr,
188
+ },
189
+ {
190
+ "params": [
191
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
192
+ ],
193
+ "weight_decay": 0.0,
194
+ "lr": self.args.mm_projector_lr,
195
+ },
196
+ ]
197
+
198
+ elif self.args.cross_attention_layer_lr:
199
+ cross_attn_parameters = [name for name, _ in opt_model.named_parameters() if "cross_attn_" in name]
200
+ optimizer_grouped_parameters = [
201
+ {
202
+ "params": [
203
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in cross_attn_parameters and p.requires_grad)
204
+ ],
205
+ "weight_decay": self.args.weight_decay,
206
+ },
207
+ {
208
+ "params": [
209
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in cross_attn_parameters and p.requires_grad)
210
+ ],
211
+ "weight_decay": 0.0,
212
+ },
213
+ {
214
+ "params": [
215
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in cross_attn_parameters and p.requires_grad)
216
+ ],
217
+ "weight_decay": self.args.weight_decay,
218
+ "lr": self.args.cross_attention_layer_lr,
219
+ },
220
+ {
221
+ "params": [
222
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in cross_attn_parameters and p.requires_grad)
223
+ ],
224
+ "weight_decay": 0.0,
225
+ "lr": self.args.cross_attention_layer_lr,
226
+ },
227
+ ]
228
+
229
+ elif self.args.vision_tower_lr:
230
+ vision_tower_parameters = [name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
231
+ optimizer_grouped_parameters = [
232
+ {
233
+ "params": [
234
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in vision_tower_parameters and p.requires_grad)
235
+ ],
236
+ "weight_decay": self.args.weight_decay,
237
+ },
238
+ {
239
+ "params": [
240
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in vision_tower_parameters and p.requires_grad)
241
+ ],
242
+ "weight_decay": 0.0,
243
+ },
244
+ {
245
+ "params": [
246
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in vision_tower_parameters and p.requires_grad)
247
+ ],
248
+ "weight_decay": self.args.weight_decay,
249
+ "lr": self.args.vision_tower_lr,
250
+ },
251
+ {
252
+ "params": [
253
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in vision_tower_parameters and p.requires_grad)
254
+ ],
255
+ "weight_decay": 0.0,
256
+ "lr": self.args.vision_tower_lr,
257
+ },
258
+ ]
259
+
260
+ else:
261
+ optimizer_grouped_parameters = [
262
+ {
263
+ "params": [
264
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
265
+ ],
266
+ "weight_decay": self.args.weight_decay,
267
+ },
268
+ {
269
+ "params": [
270
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
271
+ ],
272
+ "weight_decay": 0.0,
273
+ },
274
+ ]
275
+
276
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
277
+
278
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
279
+ if optimizer_cls.__name__ == "Adam8bit":
280
+ import bitsandbytes
281
+
282
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
283
+
284
+ skipped = 0
285
+ for module in opt_model.modules():
286
+ if isinstance(module, nn.Embedding):
287
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
288
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
289
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
290
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
291
+ logger.info(f"skipped: {skipped/2**20}M params")
292
+
293
+ return self.optimizer
294
+
295
+ # Min: save all the model parameters even during pretraining
296
+ # def _save_checkpoint(self, model, trial, metrics=None):
297
+ # if getattr(self.args, 'tune_mm_mlp_adapter', False):
298
+ # from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
299
+ # checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
300
+
301
+ # run_dir = self._get_output_dir(trial=trial)
302
+ # output_dir = os.path.join(run_dir, checkpoint_folder)
303
+
304
+ # # Only save Adapter
305
+ # keys_to_match = ['mm_projector', 'vision_resampler']
306
+ # if getattr(self.args, "use_im_start_end", False):
307
+ # keys_to_match.extend(['embed_tokens', 'embed_in'])
308
+
309
+ # weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
310
+
311
+ # if self.args.local_rank == 0 or self.args.local_rank == -1:
312
+ # self.model.config.save_pretrained(output_dir)
313
+ # torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
314
+ # else:
315
+ # super(LlavaTrainer, self)._save_checkpoint(model, trial, metrics)
316
+
317
+ # def _save(self, output_dir: Optional[str] = None, state_dict=None):
318
+ # if getattr(self.args, 'tune_mm_mlp_adapter', False):
319
+ # pass
320
+ # else:
321
+ # super(LlavaTrainer, self)._save(output_dir, state_dict)
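For reference, a minimal sketch of the parameter-group mechanism create_optimizer relies on to give the multimodal projector (or vision tower / cross-attention layers) its own learning rate. The module names and learning rates are assumptions for illustration, not values from this repo.

# Sketch of per-module learning rates via optimizer parameter groups.
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "backbone": nn.Linear(16, 16),
    "mm_projector": nn.Linear(16, 16),   # assumed module name, mirroring the code above
})
base_lr, projector_lr, weight_decay = 1e-5, 1e-4, 0.0

param_groups = [
    {"params": [p for n, p in model.named_parameters() if "mm_projector" not in n],
     "lr": base_lr, "weight_decay": weight_decay},
    {"params": [p for n, p in model.named_parameters() if "mm_projector" in n],
     "lr": projector_lr, "weight_decay": weight_decay},
]
optimizer = torch.optim.AdamW(param_groups)
print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0001]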
llava/utils.py ADDED
@@ -0,0 +1,212 @@
1
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
2
+
3
+ import datetime
4
+ import time
5
+ import logging
6
+ import logging.handlers
7
+ import os
8
+ import sys
9
+
10
+ import requests
11
+ import torch
12
+ import transformers
13
+ from transformers.integrations import is_deepspeed_zero3_enabled
14
+
15
+ from llava.constants import LOGDIR
16
+
17
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
18
+ moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
19
+
20
+ handler = None
21
+
22
+
23
+ def build_logger(logger_name, logger_filename):
24
+ global handler
25
+
26
+ formatter = logging.Formatter(
27
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ )
30
+
31
+ # Set the format of root handlers
32
+ if not logging.getLogger().handlers:
33
+ logging.basicConfig(level=logging.INFO)
34
+ logging.getLogger().handlers[0].setFormatter(formatter)
35
+
36
+ # Redirect stdout and stderr to loggers
37
+ stdout_logger = logging.getLogger("stdout")
38
+ stdout_logger.setLevel(logging.INFO)
39
+ sl = StreamToLogger(stdout_logger, logging.INFO)
40
+ sys.stdout = sl
41
+
42
+ stderr_logger = logging.getLogger("stderr")
43
+ stderr_logger.setLevel(logging.ERROR)
44
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
45
+ sys.stderr = sl
46
+
47
+ # Get logger
48
+ logger = logging.getLogger(logger_name)
49
+ logger.setLevel(logging.INFO)
50
+
51
+ # Add a file handler for all loggers
52
+ if handler is None:
53
+ os.makedirs(LOGDIR, exist_ok=True)
54
+ filename = os.path.join(LOGDIR, logger_filename)
55
+ handler = logging.handlers.TimedRotatingFileHandler(
56
+ filename, when='D', utc=True, encoding='UTF-8')
57
+ handler.setFormatter(formatter)
58
+
59
+ for name, item in logging.root.manager.loggerDict.items():
60
+ if isinstance(item, logging.Logger):
61
+ item.addHandler(handler)
62
+
63
+ return logger
64
+
65
+
66
+ class StreamToLogger(object):
67
+ """
68
+ Fake file-like stream object that redirects writes to a logger instance.
69
+ """
70
+ def __init__(self, logger, log_level=logging.INFO):
71
+ self.terminal = sys.stdout
72
+ self.logger = logger
73
+ self.log_level = log_level
74
+ self.linebuf = ''
75
+
76
+ def __getattr__(self, attr):
77
+ return getattr(self.terminal, attr)
78
+
79
+ def write(self, buf):
80
+ temp_linebuf = self.linebuf + buf
81
+ self.linebuf = ''
82
+ for line in temp_linebuf.splitlines(True):
83
+ # From the io.TextIOWrapper docs:
84
+ # On output, if newline is None, any '\n' characters written
85
+ # are translated to the system default line separator.
86
+ # By default sys.stdout.write() expects '\n' newlines and then
87
+ # translates them so this is still cross platform.
88
+ if line[-1] == '\n':
89
+ self.logger.log(self.log_level, line.rstrip())
90
+ else:
91
+ self.linebuf += line
92
+
93
+ def flush(self):
94
+ if self.linebuf != '':
95
+ self.logger.log(self.log_level, self.linebuf.rstrip())
96
+ self.linebuf = ''
97
+
98
+
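A small sketch of what StreamToLogger does once build_logger swaps it in for sys.stdout: writes are buffered and emitted through the logger one completed line at a time. It assumes the llava package from this commit is importable; the logger name and format are illustrative.

# Sketch only: exercise StreamToLogger directly instead of replacing sys.stdout.
import logging
from llava.utils import StreamToLogger

logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
demo_logger = logging.getLogger("stdout-demo")

stream = StreamToLogger(demo_logger, logging.INFO)
stream.write("first line\npartial ")
stream.write("line completed\n")
stream.flush()
# -> INFO | stdout-demo | first line
# -> INFO | stdout-demo | partial line completed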
99
+ def disable_torch_init():
100
+ """
101
+ Disable the redundant torch default initialization to accelerate model creation.
102
+ """
103
+ import torch
104
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
105
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
106
+
107
+
108
+ def violates_moderation(text):
109
+ """
110
+ Check whether the text violates OpenAI moderation API.
111
+ """
112
+ url = "https://api.openai.com/v1/moderations"
113
+ headers = {"Content-Type": "application/json",
114
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
115
+ text = text.replace("\n", "")
116
+ data = "{" + '"input": ' + f'"{text}"' + "}"
117
+ data = data.encode("utf-8")
118
+ try:
119
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
120
+ flagged = ret.json()["results"][0]["flagged"]
121
+ except requests.exceptions.RequestException as e:
122
+ flagged = False
123
+ except KeyError as e:
124
+ flagged = False
125
+
126
+ return flagged
127
+
128
+
129
+ def pretty_print_semaphore(semaphore):
130
+ if semaphore is None:
131
+ return "None"
132
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
133
+
134
+
135
+
136
+ @torch.no_grad()
137
+ def load_state_dict_into_model(model_to_load, state_dict, start_prefix=""):
138
+ # copied and altered from:
139
+ # https://github.com/huggingface/transformers/blob/9d35edbb30625489bf286a9b15aed0c5a3119c1c/src/transformers/modeling_utils.py#L650
140
+ # https://github.com/baaivision/EVA/blob/2ca37a8c0d82b9496754f3fa9c3966b4caa54d75/EVA-CLIP-18B/shinji/eva_clip/factory.py#L168
141
+
142
+ # copy state_dict so _load_from_state_dict can modify it
143
+ metadata = getattr(state_dict, "_metadata", None)
144
+ state_dict = state_dict.copy()
145
+ if metadata is not None:
146
+ state_dict._metadata = metadata
147
+ error_msgs = []
148
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
149
+ # so we need to apply the function recursively.
150
+ def load(module: torch.nn.Module, prefix=""):
151
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
152
+ args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
153
+ # Parameters of the module and its children are stored in `state_dict` under keys that start with `prefix`.
154
+ if is_deepspeed_zero3_enabled():
155
+ import deepspeed
156
+ with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
157
+ if torch.distributed.get_rank() == 0:
158
+ module._load_from_state_dict(*args)
159
+ else:
160
+ module._load_from_state_dict(*args)
161
+ for name, child in module._modules.items():
162
+ if child is not None:
163
+ load(child, prefix + name + ".")
164
+
165
+ load(model_to_load, prefix=start_prefix)
166
+ # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
167
+ # it's safe to delete it.
168
+ del state_dict
169
+ return error_msgs
170
+
171
+
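A minimal sketch of calling load_state_dict_into_model outside DeepSpeed: it returns a list of per-tensor error messages (e.g. shape mismatches) instead of raising on a strict load. The toy model and checkpoint below are illustrative.

# Sketch: load a state dict and inspect the collected error messages.
import torch.nn as nn
from llava.utils import load_state_dict_into_model

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
checkpoint = {k: v.clone() for k, v in model.state_dict().items()}

errors = load_state_dict_into_model(model, checkpoint)
print(errors)  # [] when every tensor loads cleanly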
172
+ class Timer:
173
+ def __init__(self):
174
+ self.start_time = None
175
+ self.elapsed_time = 0
176
+
177
+ def start(self):
178
+ self.start_time = time.time()
179
+
180
+ def reset(self):
181
+ self.start_time = None
182
+ self.elapsed_time = 0
183
+
184
+ def get_elapsed_time(self):
185
+ if self.start_time is not None:
186
+ return self.elapsed_time + (time.time() - self.start_time)
+ return self.elapsed_time  # timer not started yet
187
+
188
+
189
+ class TimeoutTerminateCallback(transformers.TrainerCallback):
190
+ def __init__(self, args, total_time_limit=240, pre_terminate_time=10):
191
+ self.training_args = args
192
+ self.total_time_limit = total_time_limit
193
+ self.pre_terminate_time = pre_terminate_time
194
+ self.timer = Timer()
195
+ self.timer.start()
196
+
197
+ if args.local_rank == 0:
198
+ print(f"Timer for terminate callback has been set.\nTotal limit: {total_time_limit}min\nPre terminate time: {pre_terminate_time}min")
199
+
200
+ self.time_to_kill = (total_time_limit - pre_terminate_time) * 60
201
+
202
+
203
+ def on_step_end(self, args, state, control, model, **kwargs):
204
+ elapsed_time = self.timer.get_elapsed_time()
205
+
206
+ if elapsed_time > self.time_to_kill:
207
+ if args.local_rank == 0:
208
+ print("Timeout, start to save checkpoint....")
209
+ control.should_save = True
210
+ control.should_training_stop = True
211
+
212
+ return control
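Finally, a sketch of the callback's decision logic in isolation: once elapsed time passes (total_time_limit - pre_terminate_time) minutes, it flips the Trainer control flags so the run checkpoints and stops. In practice the callback would be passed to a transformers Trainer via callbacks=[...]; the SimpleNamespace stand-in for the training args is an assumption for this sketch.

# Sketch: drive TimeoutTerminateCallback without a full training run.
from types import SimpleNamespace
from transformers import TrainerControl, TrainerState
from llava.utils import TimeoutTerminateCallback

args = SimpleNamespace(local_rank=0)
callback = TimeoutTerminateCallback(args, total_time_limit=240, pre_terminate_time=10)

control = callback.on_step_end(args, TrainerState(), TrainerControl(), model=None)
print(control.should_training_stop)  # False until ~230 minutes have elapsed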