lixin4ever Lillyr committed on
Commit
44d8da2
·
verified ·
1 Parent(s): f643e0e
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +14 -0
  3. app.py +562 -0
  4. demo/.DS_Store +0 -0
  5. demo/images/1.jpg +3 -0
  6. demo/images/2.jpg +3 -0
  7. demo/images/3.jpg +3 -0
  8. demo/images/4.jpg +3 -0
  9. demo/images/5.jpg +3 -0
  10. demo/images/6.jpg +3 -0
  11. demo/images/7.jpg +3 -0
  12. demo/images/8.jpg +3 -0
  13. demo/images/LICENSE +3 -0
  14. demo/videos/1.mp4 +3 -0
  15. demo/videos/2.mp4 +3 -0
  16. demo/videos/3.mp4 +3 -0
  17. demo/videos/4.mp4 +3 -0
  18. requirements.txt +48 -0
  19. videollama3/.DS_Store +0 -0
  20. videollama3/__init__.py +239 -0
  21. videollama3/constants.py +46 -0
  22. videollama3/infer.py +82 -0
  23. videollama3/mm_utils.py +704 -0
  24. videollama3/model/__init__.py +166 -0
  25. videollama3/model/__pycache__/__init__.cpython-310.pyc +0 -0
  26. videollama3/model/__pycache__/encoder.cpython-310.pyc +0 -0
  27. videollama3/model/__pycache__/processor.cpython-310.pyc +0 -0
  28. videollama3/model/__pycache__/projector.cpython-310.pyc +0 -0
  29. videollama3/model/__pycache__/region_encoder.cpython-310.pyc +0 -0
  30. videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc +0 -0
  31. videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc +0 -0
  32. videollama3/model/damovl_encoder/__init__.py +3 -0
  33. videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
  34. videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc +0 -0
  35. videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
  36. videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc +0 -0
  37. videollama3/model/damovl_encoder/configuration_damovl_encoder.py +71 -0
  38. videollama3/model/damovl_encoder/image_processing.py +472 -0
  39. videollama3/model/damovl_encoder/modeling_damovl_encoder.py +542 -0
  40. videollama3/model/encoder.py +385 -0
  41. videollama3/model/processor.py +366 -0
  42. videollama3/model/projector.py +160 -0
  43. videollama3/model/qwen2vl_encoder/__init__.py +3 -0
  44. videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
  45. videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc +0 -0
  46. videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
  47. videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc +0 -0
  48. videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py +72 -0
  49. videollama3/model/qwen2vl_encoder/image_processing.py +469 -0
  50. videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py +367 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo/videos/ filter=lfs diff=lfs merge=lfs -text
37
+ demo/videos/3.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ demo/videos/4.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ demo/videos/1.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ demo/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ demo/images/4.jpg filter=lfs diff=lfs merge=lfs -text
42
+ demo/images/5.jpg filter=lfs diff=lfs merge=lfs -text
43
+ demo/images/6.jpg filter=lfs diff=lfs merge=lfs -text
44
+ demo/images/8.jpg filter=lfs diff=lfs merge=lfs -text
45
+ demo/images/1.jpg filter=lfs diff=lfs merge=lfs -text
46
+ demo/images/2.jpg filter=lfs diff=lfs merge=lfs -text
47
+ demo/images/3.jpg filter=lfs diff=lfs merge=lfs -text
48
+ demo/images/7.jpg filter=lfs diff=lfs merge=lfs -text
49
+ demo/images/LICENSE filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,562 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from transformers import SamModel, SamProcessor
5
+ from PIL import Image
6
+ import os
7
+ import cv2
8
+ import argparse
9
+ import sys
10
+ # NOTE: disable_torch_init (imported below) speeds up model creation and has no effect on outputs, since pretrained weights are loaded afterwards
11
+ sys.path.append('./')
12
+ from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
13
+ from videollama3.mm_utils import load_images
14
+ from videollama3.mm_utils import load_video
15
+
16
+
17
+ color_rgb = (1.0, 1.0, 1.0)
18
+ color_rgbs = [
19
+ (1.0, 1.0, 1.0),
20
+ (1.0, 0.0, 0.0),
21
+ (0.0, 1.0, 1.0),
22
+ (0.0, 1.0, 0.0),
23
+ (0.0, 0.0, 1.0),
24
+ (1.0, 0.0, 1.0),
25
+ ]
26
+
27
+ mask_list = []
28
+ mask_raw_list = []
29
+ mask_list_video = []
30
+ mask_raw_list_video = []
31
+
32
+ def extract_first_frame_from_video(video):
33
+ cap = cv2.VideoCapture(video)
34
+ success, frame = cap.read()
35
+ cap.release()
36
+ if success:
37
+ return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
38
+ return None
39
+
40
+ def extract_points_from_mask(mask_pil):
41
+ mask = np.asarray(mask_pil)[..., 0]
42
+ coords = np.nonzero(mask)
43
+ coords = np.stack((coords[1], coords[0]), axis=1)
44
+
45
+ return coords
46
+
47
+ def add_contour(img, mask, color=(1., 1., 1.)):
48
+ img = img.copy()
49
+
50
+ mask = mask.astype(np.uint8) * 255
51
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
52
+ cv2.drawContours(img, contours, -1, color, thickness=8)
53
+
54
+ return img
55
+
56
+ def generate_masks(image):
57
+ global mask_list
58
+ global mask_raw_list
59
+ image['image'] = image['background'].convert('RGB')
60
+ # del image['background'], image['composite']
61
+ assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
62
+
63
+ mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
64
+ points = extract_points_from_mask(mask)
65
+ np.random.seed(0)
66
+ if points.shape[0] == 0:
67
+ raise gr.Error("No points selected")
68
+
69
+ points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
70
+ points = points[points_selected_indices]
71
+ coords = [points.tolist()]
72
+ mask_np = apply_sam(image['image'], coords)
73
+
74
+ mask_raw_list.append(mask_np)
75
+ mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
76
+
77
+ mask_list.append((mask_image, f"<region{len(mask_list)}>"))
78
+ # Return a list containing the mask image.
79
+ image['layers'] = []
80
+ image['composite'] = image['background']
81
+ return mask_list, image
82
+
83
+
84
+ def generate_masks_video(image):
85
+ global mask_list_video
86
+ global mask_raw_list_video
87
+ image['image'] = image['background'].convert('RGB')
88
+ # del image['background'], image['composite']
89
+ assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
90
+
91
+ mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
92
+ points = extract_points_from_mask(mask)
93
+ np.random.seed(0)
94
+ if points.shape[0] == 0:
95
+ raise gr.Error("No points selected")
96
+
97
+ points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
98
+ points = points[points_selected_indices]
99
+ coords = [points.tolist()]
100
+ mask_np = apply_sam(image['image'], coords)
101
+
102
+ mask_raw_list_video.append(mask_np)
103
+ mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
104
+
105
+ mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
106
+ # Return a list containing the mask image.
107
+ image['layers'] = []
108
+ image['composite'] = image['background']
109
+ return mask_list_video, image
110
+
111
+
112
+
113
+ def describe(image, mode, query, masks=None):
114
+ # Create an image object from the uploaded image
115
+ # print(image.keys())
116
+
117
+ image['image'] = image['background'].convert('RGB')
118
+ # del image['background'], image['composite']
119
+ assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
120
+
121
+ # Handle both hex and rgba color formats
122
+
123
+ img_np = np.asarray(image['image']).astype(float) / 255.
124
+ if mode=='Caption':
125
+ mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
126
+
127
+ points = extract_points_from_mask(mask)
128
+
129
+ np.random.seed(0)
130
+
131
+ if points.shape[0] == 0:
132
+ if masks is None or len(masks) > 1:
133
+ raise gr.Error("No points selected")
134
+
135
+ else:
136
+ # Randomly sample 8 points from the mask
137
+ # Follow DAM https://github.com/NVlabs/describe-anything
138
+ points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
139
+ points = points[points_selected_indices]
140
+
141
+ coords = [points.tolist()]
142
+
143
+ mask_np = apply_sam(image['image'], coords)
144
+
145
+ masks = []
146
+ masks.append(mask_np)
147
+ mask_ids = [0]
148
+
149
+ img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
150
+ img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
151
+ else:
152
+ masks = mask_raw_list
153
+ img_with_contour_np = img_np.copy()
154
+
155
+ mask_ids = []
156
+ for i, mask_np in enumerate(masks):
157
+ img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
158
+ img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
159
+ mask_ids.append(0)
160
+
161
+ masks = np.stack(masks, axis=0)
162
+ masks = torch.from_numpy(masks).to(torch.uint8)
163
+
164
+
165
+
166
+ img = np.asarray(image['image'])
167
+
168
+
169
+ if mode == "Caption":
170
+ query = '<image>\nPlease describe the <region> in the image in detail.'
171
+ else:
172
+ if len(masks)==1:
173
+ prefix = "<image>\nThere is 1 region in the image: <region0> <region>. "
174
+ else:
175
+ prefix = f"<image>\nThere are {len(masks)} regions in the image: "
176
+ for i in range(len(masks)):
177
+ prefix += f"<region{i}><region>, "
178
+ prefix = prefix[:-2]+'. '
179
+ query = prefix + query
180
+ # print(query)
181
+
182
+ image['layers'] = []
183
+ image['composite'] = image['background']
184
+
185
+ text = ""
186
+ yield img_with_contour_pil, text, image
187
+
188
+ for token in get_model_output(
189
+ [img],
190
+ query,
191
+ model=model,
192
+ tokenizer=tokenizer,
193
+ masks=masks,
194
+ mask_ids=mask_ids,
195
+ modal='image',
196
+ image_downsampling=1,
197
+ streaming=True,
198
+ ):
199
+ text += token
200
+ yield gr.update(), text, gr.update()
201
+
202
+
203
+ def load_first_frame(video_path):
204
+ cap = cv2.VideoCapture(video_path)
205
+ ret, frame = cap.read()
206
+ cap.release()
207
+ if not ret:
208
+ raise gr.Error("Could not read the video file.")
209
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
210
+ image = Image.fromarray(frame)
211
+ return image
212
+
213
+ def describe_video(video_path, mode, query, annotated_frame, masks=None):
214
+ global mask_list_video
215
+ # Create a temporary directory to save extracted video frames
216
+ cap = cv2.VideoCapture(video_path)
217
+
218
+ video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])
219
+
220
+ annotated_frame['image'] = annotated_frame['background'].convert('RGB')
221
+
222
+ # Process the annotated frame from the image editor
223
+ if isinstance(annotated_frame, dict):
224
+ # Get the composite image with annotations
225
+ frame_img = annotated_frame.get("image", annotated_frame.get("background"))
226
+ if frame_img is None:
227
+ raise gr.Error("No valid annotation found in the image editor.")
228
+ frame_img = frame_img.convert("RGB")
229
+
230
+ # Get the annotation layer
231
+ if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
232
+ mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
233
+ else:
234
+ mask = Image.new("RGB", frame_img.size, 0)
235
+ else:
236
+ frame_img = annotated_frame.convert("RGB")
237
+ mask = Image.new("RGB", frame_img.size, 0)
238
+
239
+ img_np = np.asarray(annotated_frame['image']).astype(float) / 255.
240
+ # Extract points from the annotated mask (using the first channel)
241
+ if mode == "Caption":
242
+ points = extract_points_from_mask(mask)
243
+ np.random.seed(0)
244
+ if points.shape[0] == 0:
245
+ raise gr.Error("No points were selected in the annotation.")
246
+ # Randomly select up to 8 points
247
+ # Follow DAM https://github.com/NVlabs/describe-anything
248
+ points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
249
+ points = points[points_selected_indices]
250
+
251
+ # print(f"Selected points (to SAM): {points}")
252
+
253
+ coords = [points.tolist()]
254
+
255
+ mask_np = apply_sam(annotated_frame['image'], coords)
256
+
257
+ masks = []
258
+ masks.append(mask_np)
259
+ mask_ids = [0]
260
+
261
+ # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
262
+ # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
263
+
264
+
265
+ else:
266
+ masks = mask_raw_list_video
267
+ img_with_contour_np = img_np.copy()
268
+
269
+ mask_ids = []
270
+ for i, mask_np in enumerate(masks):
271
+ # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
272
+ # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
273
+ mask_ids.append(0)
274
+
275
+
276
+
277
+ masks = np.stack(masks, axis=0)
278
+ masks = torch.from_numpy(masks).to(torch.uint8)
279
+
280
+
281
+
282
+
283
+ if mode == "Caption":
284
+ query = '<video>\nPlease describe the <region> in the video in detail.'
285
+ else:
286
+ if len(masks)==1:
287
+ prefix = "<video>\nThere is 1 object in the video: <object0> <region>. "
288
+ else:
289
+ prefix = f"<video>\nThere are {len(masks)} objects in the video: "
290
+ for i in range(len(masks)):
291
+ prefix += f"<object{i}><region>, "
292
+ prefix = prefix[:-2]+'. '
293
+ query = prefix + query
294
+
295
+ # Initialize empty text
296
+ # text = description_generator
297
+ annotated_frame['layers'] = []
298
+ annotated_frame['composite'] = annotated_frame['background']
299
+
300
+ if mode=="Caption":
301
+ mask_list_video = []
302
+ mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
303
+ mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
304
+ text = ""
305
+ yield frame_img, text, mask_list_video
306
+
307
+ for token in get_model_output(
308
+ video_tensor,
309
+ query,
310
+ model=model,
311
+ tokenizer=tokenizer,
312
+ masks=masks,
313
+ mask_ids=mask_ids,
314
+ modal='video',
315
+ streaming=True,
316
+ ):
317
+ text += token
318
+ yield gr.update(), text, gr.update()
319
+
320
+
321
+
322
+ def apply_sam(image, input_points):
323
+ inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
324
+
325
+ with torch.no_grad():
326
+ outputs = sam_model(**inputs)
327
+
328
+ masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
329
+ scores = outputs.iou_scores[0, 0]
330
+
331
+ mask_selection_index = scores.argmax()
332
+
333
+ mask_np = masks[mask_selection_index].numpy()
334
+
335
+ return mask_np
336
+
337
+ def clear_masks():
338
+ global mask_list
339
+ global mask_raw_list
340
+ mask_list = []
341
+ mask_raw_list = []
342
+ return []
343
+
344
+
345
+ def clear_masks_video():
346
+ global mask_list_video
347
+ global mask_raw_list_video
348
+ mask_list_video = []
349
+ mask_raw_list_video = []
350
+ return []
351
+
352
+
353
+ if __name__ == "__main__":
354
+ parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
355
+ parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
356
+ parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
357
+ parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
358
+ parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
359
+ parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
360
+
361
+ args_cli = parser.parse_args()
362
+ print(args_cli.model_path)
363
+
364
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
365
+
366
+ HEADER = ("""
367
+ <div>
368
+ <h1>VideoRefer X VideoLLaMA3 Demo</h1>
369
+ <h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
370
+ <h5 style="margin: 0;">If you enjoy this demo, please give us a star ⭐ on GitHub or 💖 on this Space.</h5>
371
+ </div>
372
+ </div>
373
+ <div style="display: flex; justify-content: left; margin-top: 10px;">
374
+ <a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
375
+ <a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
376
+ <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
377
+ </div>
378
+ """)
379
+
380
+ with gr.Row():
381
+ with gr.Column():
382
+ gr.HTML(HEADER)
383
+
384
+
385
+ image_tips = """
386
+ ### 💡 Tips:
387
+
388
+ 🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
389
+
390
+ 🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
391
+
392
+ 🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding region ID to ask questions.
393
+
394
+ 📌 Click the button 'Clear Masks' to clear the current generated masks.
395
+
396
+ """
397
+
398
+ video_tips = """
399
+ ### 💡 Tips:
400
+ ⚠️ For video mode, we only support masking on the first frame in this demo.
401
+
402
+ 🧸 Upload a video, and you can use the drawing tool ✍️ to highlight the areas you're interested in on the first frame.
403
+
404
+ 🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
405
+
406
+ 🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
407
+
408
+ 📌 Click the button 'Clear Masks' to clear the current generated masks.
409
+
410
+ """
411
+
412
+
413
+ with gr.TabItem("Image"):
414
+ with gr.Row():
415
+ with gr.Column():
416
+ image_input = gr.ImageEditor(
417
+ label="Image",
418
+ type="pil",
419
+ sources=['upload'],
420
+ brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
421
+ eraser=True,
422
+ layers=False,
423
+ transforms=[],
424
+ height=300,
425
+ )
426
+ generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
427
+ mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
428
+ query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)
429
+
430
+ submit_btn = gr.Button("Generate Caption", variant="primary")
431
+ submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
432
+ gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")
433
+
434
+ with gr.Column():
435
+ mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
436
+ output_image = gr.Image(label="Image with Mask", visible=True, height=400)
437
+ description = gr.Textbox(label="Output", visible=True)
438
+
439
+ clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
440
+ gr.Markdown(image_tips)
441
+
442
+ with gr.TabItem("Video"):
443
+ with gr.Row():
444
+ with gr.Column():
445
+ video_input = gr.Video(label="Video")
446
+ # load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
447
+ first_frame = gr.ImageEditor(
448
+ label="Annotate First Frame",
449
+ type="pil",
450
+ sources=['upload'],
451
+ brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
452
+ eraser=True,
453
+ layers=False,
454
+ transforms=[],
455
+ height=300,
456
+ )
457
+ generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
458
+ gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
459
+
460
+ with gr.Column():
461
+ mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
462
+ mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
463
+
464
+ query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)
465
+
466
+ submit_btn_video = gr.Button("Generate Caption", variant="primary")
467
+ submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
468
+ description_video = gr.Textbox(label="Output", visible=True)
469
+
470
+ clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
471
+
472
+ gr.Markdown(video_tips)
473
+
474
+
475
+ def toggle_query_and_generate_button(mode):
476
+ query_visible = mode == "QA"
477
+ caption_visible = mode == "Caption"
478
+ global mask_list
479
+ global mask_raw_list
480
+ mask_list = []
481
+ mask_raw_list = []
482
+ return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], ""
483
+
484
+ video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
485
+
486
+ mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description])
487
+
488
+ def toggle_query_and_generate_button_video(mode):
489
+ query_visible = mode == "QA"
490
+ caption_visible = mode == "Caption"
491
+ global mask_list_video
492
+ global mask_raw_list_video
493
+ mask_list_video = []
494
+ mask_raw_list_video = []
495
+ return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), []
496
+
497
+
498
+ mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video])
499
+
500
+ submit_btn.click(
501
+ fn=describe,
502
+ inputs=[image_input, mode, query],
503
+ outputs=[output_image, description, image_input],
504
+ api_name="describe"
505
+ )
506
+
507
+ submit_btn1.click(
508
+ fn=describe,
509
+ inputs=[image_input, mode, query],
510
+ outputs=[output_image, description, image_input],
511
+ api_name="describe"
512
+ )
513
+
514
+ generate_mask_btn.click(
515
+ fn=generate_masks,
516
+ inputs=[image_input],
517
+ outputs=[mask_output, image_input]
518
+ )
519
+
520
+ generate_mask_btn_video.click(
521
+ fn=generate_masks_video,
522
+ inputs=[first_frame],
523
+ outputs=[mask_output_video, first_frame]
524
+ )
525
+
526
+ clear_masks_btn.click(
527
+ fn=clear_masks,
528
+ outputs=[mask_output]
529
+ )
530
+
531
+ clear_masks_btn_video.click(
532
+ fn=clear_masks_video,
533
+ outputs=[mask_output_video]
534
+ )
535
+
536
+ submit_btn_video.click(
537
+ fn=describe_video,
538
+ inputs=[video_input, mode_video, query_video, first_frame],
539
+ outputs=[first_frame, description_video, mask_output_video],
540
+ api_name="describe_video"
541
+ )
542
+
543
+ submit_btn_video1.click(
544
+ fn=describe_video,
545
+ inputs=[video_input, mode_video, query_video, first_frame],
546
+ outputs=[first_frame, description_video, mask_output_video],
547
+ api_name="describe_video"
548
+ )
549
+
550
+
551
+
552
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
553
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
554
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
555
+
556
+ disable_torch_init()
557
+
558
+
559
+ model, processor, tokenizer = model_init(args_cli.model_path)
560
+
561
+
562
+ demo.launch()
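For reference, the region prompts assembled by `describe` and `describe_video` above can be reproduced with the short, self-contained sketch below. The helper name `build_region_prompt` is illustrative and not part of app.py; only the template strings mirror the (grammar-corrected) code above.

```python
# Illustrative sketch of the prompt templates built in describe() / describe_video() above.
def build_region_prompt(num_masks: int, question: str, modal: str = "image") -> str:
    if modal == "image":
        tag, noun, scene = "<image>", "region", "image"
    else:
        tag, noun, scene = "<video>", "object", "video"
    if num_masks == 1:
        prefix = f"{tag}\nThere is 1 {noun} in the {scene}: <{noun}0> <region>. "
    else:
        prefix = f"{tag}\nThere are {num_masks} {noun}s in the {scene}: "
        prefix += ", ".join(f"<{noun}{i}><region>" for i in range(num_masks)) + ". "
    return prefix + question

print(build_region_prompt(2, "What is the relationship between <region0> and <region1>?"))
# <image>
# There are 2 regions in the image: <region0><region>, <region1><region>. What is the relationship ...
```

In Caption mode the query is instead the fixed string `'<image>\nPlease describe the <region> in the image in detail.'` (or its `<video>` counterpart).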
demo/.DS_Store ADDED
Binary file (6.15 kB).
 
demo/images/1.jpg ADDED

Git LFS Details

  • SHA256: 57f222d08703255914ed6cbda7d0c5fd8b772d7b975f3ffd73ee47f24f7eaabe
  • Pointer size: 131 Bytes
  • Size of remote file: 491 kB
demo/images/2.jpg ADDED

Git LFS Details

  • SHA256: a9011049db02799c9bf68ba228445968a4dc2d097df8f3559c4e18a8a09a4f7f
  • Pointer size: 131 Bytes
  • Size of remote file: 501 kB
demo/images/3.jpg ADDED

Git LFS Details

  • SHA256: 5c5159bf7114d08967f95475176670043115b157bf700efa34190260cd917662
  • Pointer size: 132 Bytes
  • Size of remote file: 1.03 MB
demo/images/4.jpg ADDED

Git LFS Details

  • SHA256: 39174b4188bc6d928cf0153f0d3a3224e15c9823f8cdc99b4ad6627067741bb8
  • Pointer size: 131 Bytes
  • Size of remote file: 708 kB
demo/images/5.jpg ADDED

Git LFS Details

  • SHA256: e02c393a23aadd1304497e3a9b41144df166d1cfda33ea3e00eed94e27da3aa4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
demo/images/6.jpg ADDED

Git LFS Details

  • SHA256: 1d512c06daf1b5c7919fc351c496ff65d9cac601c57ae263433c49d90d3b083e
  • Pointer size: 132 Bytes
  • Size of remote file: 3.78 MB
demo/images/7.jpg ADDED

Git LFS Details

  • SHA256: 68d5970b974101b61b1bcf5dd790485a89f85a651641990c2629d4a56de40ba8
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
demo/images/8.jpg ADDED

Git LFS Details

  • SHA256: bdb5acb53dfc78e74008d113b22f5a2fb1e2c7b33cb8eadf4983d709bfe366ba
  • Pointer size: 131 Bytes
  • Size of remote file: 335 kB
demo/images/LICENSE ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ac4c813c90895cdc79c71fdbd02715fd0c5505c24d95c5941747c904d6e93bc
3
+ size 149
demo/videos/1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad78d268f6f1ad9a457a7768665157f74c20292136cefbf6bfc2a07de940dd0a
3
+ size 804232
demo/videos/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5eebbd330be490709c1b39cd1d82ae074f3fe275487bc6b77d2aa5cd74d40d05
3
+ size 1255466
demo/videos/3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:946550c741c9dc515340ab93b203614094632191db0d8f9697bd580f4a271947
3
+ size 8743247
demo/videos/4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b06b309812947b909ce7b8eaaea94a9ca60a8452a33e3109f5f6ffb1dbf8ee6
3
+ size 1334796
requirements.txt ADDED
@@ -0,0 +1,48 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ --extra-index-url https://download.pytorch.org/whl/cu121
3
+ --extra-index-url https://download.pytorch.org/whl/cu118
4
+ # basic dependencies
5
+ torch==2.4.0
6
+ torchvision==0.19.0
7
+ datasets==2.21.0
8
+ transformers==4.46.3
9
+ tokenizers==0.20.3
10
+ deepspeed==0.15.4
11
+ accelerate==1.0.1
12
+ peft==0.4.0
13
+ timm==1.0.3
14
+ numpy==1.24.4
15
+ # data processing
16
+ decord==0.6.0
17
+ imageio==2.34.0
18
+ imageio-ffmpeg==0.4.9
19
+ moviepy==1.0.3
20
+ scenedetect==0.6.3
21
+ opencv-python==4.6.0.66
22
+ pyarrow
23
+ pysubs2
24
+ ffmpeg-python
25
+ # misc
26
+ scikit-learn==1.2.2
27
+ huggingface_hub==0.23.4
28
+ sentencepiece==0.1.99
29
+ shortuuid
30
+ einops==0.6.1
31
+ einops-exts==0.0.4
32
+ bitsandbytes==0.43.3 # for cuda 124
33
+ pydantic>=2.0
34
+ markdown2[all]
35
+ gradio==5.34.0
36
+ gradio_client==1.10.3
37
+ httpx==0.24.1
38
+ requests
39
+ openai
40
+ uvicorn
41
+ fastapi
42
+ tensorboard
43
+ wandb
44
+ tabulate
45
+ Levenshtein
46
+ pycocotools==2.0.8
47
+ spaces
48
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
videollama3/.DS_Store ADDED
Binary file (6.15 kB).
 
videollama3/__init__.py ADDED
@@ -0,0 +1,239 @@
1
+ import os
2
+ import copy
3
+ import math
4
+ import warnings
5
+ import shutil
6
+ from functools import partial
7
+
8
+ import torch
9
+
10
+ from .model import load_pretrained_model
11
+ from .model.processor import Videollama3Processor
12
+ from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, resize_image_mask
13
+ from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
14
+ from videollama3.constants import REGION_TOKEN
15
+ from transformers import TextIteratorStreamer
16
+ from threading import Thread
17
+
18
+ def disable_torch_init():
19
+ """
20
+ Disable the redundant torch default initialization to accelerate model creation.
21
+ """
22
+ import torch
23
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
24
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
25
+
26
+
27
+ def model_init(model_path=None, **kwargs):
28
+ model_path = "DAMO-NLP-SG/VideoLLaMA2-7B" if model_path is None else model_path
29
+ model_name = get_model_name_from_path(model_path)
30
+ tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)
31
+
32
+ if tokenizer.pad_token is None and tokenizer.unk_token is not None:
33
+ tokenizer.pad_token = tokenizer.unk_token
34
+
35
+ aspect_ratio = model.config.image_aspect_ratio if hasattr(model.config, "image_aspect_ratio") else "pad"
36
+ image_size = model.config.image_size if hasattr(model.config, "image_size") else 384
37
+ # NOTE: If num_frames is None, the frame sampling mode is "fps". If num_frames is not None, the frame sampling mode is "uniform".
38
+ # num_frames = model.config.num_frames
39
+ model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)
40
+ processor = {
41
+ 'image': load_images,
42
+ 'video': load_video,
43
+ 'text': None
44
+ }
45
+
46
+ return model, processor, tokenizer
47
+
48
+
49
+ def get_model_output(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
50
+ streaming = kwargs.pop('streaming', False)
51
+ if streaming:
52
+ return mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=True, **kwargs)
53
+ else:
54
+ output = mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=False, **kwargs)
55
+ return next(output)
56
+
57
+
58
+ def mm_infer(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
59
+ """Inference API of VideoLLaMA3 for image and video understanding.
60
+
61
+ Args:
62
+ model: VideoLLaMA3 model.
63
+ images_or_videos (torch.Tensor): image tensor (1, C, H, W) / video tensor (T, C, H, W).
64
+ instruct (str): text instruction for understanding video.
65
+ tokenizer: tokenizer.
66
+ do_sample (bool): whether to sample.
67
+ modal (str): inference modality.
68
+ Returns:
69
+ str: response of the model.
70
+ """
71
+ mask_ids = kwargs.pop('mask_ids', None)
72
+ masks = kwargs.pop('masks', None)
73
+ streaming = kwargs.pop('streaming', False)
74
+ if modal == 'image':
75
+ modal_token = DEFAULT_IMAGE_TOKEN
76
+ images = images_or_videos
77
+ additional_frames = images.copy()
78
+ timestamps = None
79
+ elif modal == 'video':
80
+ modal_token = DEFAULT_VIDEO_TOKEN
81
+ images, timestamps, additional_frames = images_or_videos
82
+ elif modal == 'text':
83
+ modal_token = ''
84
+ else:
85
+ raise ValueError(f"Unsupported modal: {modal}")
86
+
87
+ vlprocessor = Videollama3Processor(model.get_vision_encoder().image_processor, tokenizer)
88
+ vlprocessor.tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)
89
+
90
+ model.config.image_token_index = vlprocessor.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
91
+
92
+ if masks is not None:
93
+ additional_frames, masks, mask_nums = resize_image_mask(additional_frames, masks, mask_ids)
94
+
95
+ for idx in range(len(mask_nums)):
96
+ instruct = instruct.replace('<region>', "["+REGION_TOKEN*mask_nums[idx]+"]", 1)
97
+
98
+
99
+ additional_images_dict = vlprocessor._process_image(additional_frames, image_downsampling=1)
100
+ additional_images = additional_images_dict['images']
101
+ # import pdb
102
+ # pdb.set_trace()
103
+
104
+
105
+ # flatten_patches1 = additional_images[0].reshape(26, 46, 3, -1)
106
+ # from matplotlib import pyplot as plt
107
+ # plt.imshow(flatten_patches1[:,:,:,0])
108
+ # plt.savefig('16.png')
109
+
110
+ additional_images_thws = additional_images_dict['grid_thws']
111
+ additional_images = (additional_images, additional_images_thws)
112
+
113
+ else:
114
+ additional_images = None
115
+
116
+
117
+ # 1. text preprocess (tag process & generate prompt).
118
+ if isinstance(instruct, str):
119
+ messages = [{'role': 'user', 'content': instruct}]
120
+ elif isinstance(instruct, list):
121
+ messages = copy.deepcopy(instruct)
122
+ else:
123
+ raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
124
+
125
+ if all(not modal_token in message["content"] for message in messages):
126
+ warnings.warn("Modal token not found in the conversation; adding it automatically at the beginning.")
127
+ messages[0]["content"] = modal_token + messages[0]["content"]
128
+
129
+ converted_messages = []
130
+ for message in messages:
131
+ chunks = message["content"].split(modal_token)
132
+ converted_messages.append({
133
+ "role": "user",
134
+ "content": []
135
+ })
136
+
137
+ for chunk_idx in range(1, 2 * len(chunks)):
138
+ if chunk_idx % 2 == 1:
139
+ chunk = chunks[chunk_idx // 2].strip()
140
+ converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
141
+ else:
142
+ if modal == 'image':
143
+ converted_messages[-1]["content"].append({"type": "image"})
144
+ elif modal == 'video':
145
+ converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
146
+
147
+ messages = converted_messages
148
+
149
+ # 2. vision preprocess (load & transform image or video).
150
+ if model.config.model_type in ['videollama3_mistral', 'videollama3_mixtral']:
151
+ system_message = [
152
+ {'role': 'system', 'content': (
153
+ """<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."""
154
+ """\n"""
155
+ """If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>""")
156
+ }
157
+ ]
158
+ else:
159
+ system_message = []
160
+
161
+ image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
162
+ # TODO: attention mask?
163
+ messages = system_message + messages
164
+ data_dict = vlprocessor(
165
+ images=images,
166
+ text=messages,
167
+ image_downsampling=image_downsampling,
168
+ return_tensors="pt",
169
+ )
170
+
171
+ torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
172
+
173
+ images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
174
+ grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]
175
+
176
+ # 3. generate response according to visual signals and prompts.
177
+ keywords = [tokenizer.eos_token]
178
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"])
179
+ stop_str = tokenizer.eos_token
180
+
181
+ do_sample = kwargs.get('do_sample', False)
182
+ temperature = kwargs.get('temperature', 0.2 if do_sample else 0.0)
183
+ top_p = kwargs.get('top_p', 0.9)
184
+ max_new_tokens = kwargs.get('max_new_tokens', 2048)
185
+ if not streaming:
186
+ with torch.inference_mode():
187
+ output_ids = model.generate(
188
+ # input_ids,
189
+ # attention_mask=attention_masks,
190
+ # images=images,
191
+ data_dict["input_ids"].cuda(),
192
+ attention_mask=data_dict["attention_mask"].cuda(),
193
+ images=[(modal, images, grid_thws)],
194
+ do_sample=do_sample,
195
+ temperature=temperature,
196
+ max_new_tokens=max_new_tokens,
197
+ top_p=top_p,
198
+ use_cache=True,
199
+ stopping_criteria=[stopping_criteria],
200
+ pad_token_id=tokenizer.eos_token_id,
201
+ additional_images=[additional_images],
202
+ masks=[masks],
203
+ )
204
+
205
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
206
+
207
+ yield outputs
208
+
209
+ else:
210
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
211
+ generation_kwargs = dict(
212
+ inputs=data_dict["input_ids"].cuda(),
213
+ attention_mask=data_dict["attention_mask"].cuda(),
214
+ images=[(modal, images, grid_thws)],
215
+ do_sample=do_sample,
216
+ temperature=temperature,
217
+ max_new_tokens=max_new_tokens,
218
+ top_p=top_p,
219
+ use_cache=True,
220
+ stopping_criteria=[stopping_criteria],
221
+ pad_token_id=tokenizer.eos_token_id,
222
+ additional_images=[additional_images],
223
+ masks=[masks],
224
+ streamer=streamer
225
+ )
226
+
227
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
228
+ thread.start()
229
+
230
+ generated_text = ""
231
+ for new_text in streamer:
232
+ generated_text += new_text
233
+ if stop_str in generated_text:
234
+ generated_text = generated_text[:generated_text.find(stop_str)]
235
+ break
236
+ yield new_text
237
+
238
+ thread.join()
239
+
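A minimal usage sketch of the API above, mirroring how infer.py and app.py call it. The dummy frame, mask, and the default demo checkpoint path are placeholders; `get_model_output` returns a full string when `streaming=False` and a token iterator when `streaming=True`.

```python
# Minimal sketch, assuming the default demo checkpoint used in app.py; the dummy frame
# and mask below are placeholders that only illustrate the expected shapes and dtypes.
import numpy as np
import torch
from videollama3 import disable_torch_init, model_init, get_model_output

disable_torch_init()
model, processor, tokenizer = model_init("DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B")

frame = np.zeros((480, 640, 3), dtype=np.uint8)        # one RGB frame (H, W, C)
masks = torch.zeros((1, 480, 640), dtype=torch.uint8)  # one binary region mask
masks[0, 100:200, 150:300] = 1
question = "<image>\nPlease describe the <region> in the image in detail."

# Non-streaming: a single decoded string is returned.
answer = get_model_output([frame], question, model=model, tokenizer=tokenizer,
                          masks=masks, mask_ids=[0], modal="image", image_downsampling=1)
print(answer)

# Streaming: iterate over tokens as they are produced by TextIteratorStreamer.
for token in get_model_output([frame], question, model=model, tokenizer=tokenizer,
                              masks=masks, mask_ids=[0], modal="image",
                              image_downsampling=1, streaming=True):
    print(token, end="", flush=True)
```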
videollama3/constants.py ADDED
@@ -0,0 +1,46 @@
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+
9
+ # Image arguments
10
+ IMAGE_TOKEN_INDEX = -200
11
+ DEFAULT_IMAGE_TOKEN = "<image>"
12
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
13
+ DEFAULT_IM_START_TOKEN = "<im_start>"
14
+ DEFAULT_IM_END_TOKEN = "<im_end>"
15
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
16
+
17
+ # Video arguments
18
+ VIDEO_TOKEN_INDEX = -201
19
+ DEFAULT_VIDEO_TOKEN = "<video>"
20
+ NUM_FRAMES = 128
21
+ MAX_FRAMES = 768
22
+ NUM_FRAMES_PER_SECOND = 1
23
+
24
+ # Region arguments
25
+ REGION_TOKEN = "<REGION>"
26
+
27
+ # Audio arguments
28
+ AUDIO_TOKEN_INDEX = -202
29
+ DEFAULT_AUDIO_TOKEN = "<audio>"
30
+
31
+ # Stream arguments
32
+ STREAM_START_TOKEN = "<|stream_start|>"
33
+ STREAM_END_TOKEN = "<|stream_end|>"
34
+ STREAM_IMAGE_TOKEN = "<stream_image>"
35
+ STREAM_FPS = 2
36
+ STREAM_IMAGE_SIZE = 224
37
+ STREAM_DOWNSAMPLING = 4
38
+ STREAM_MAX_FRAMES = 400
39
+
40
+ MODAL_INDEX_MAP = {
41
+ "<image>": -200,
42
+ "<video>": -201,
43
+ "<audio>": -202,
44
+ }
45
+
46
+ subimage_token_num=196
videollama3/infer.py ADDED
@@ -0,0 +1,82 @@
1
+ import os
2
+ os.environ["CUDA_VISIBLE_DEVICES"] = "1"
3
+
4
+ import os
5
+ import torch
6
+ import sys
7
+ sys.path.append('./')
8
+ from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
9
+ from videollama3.mm_utils import load_video
10
+
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ def infer_image(model, tokenizer):
15
+ image_path = 'demo/images/1.jpg'
16
+ image = Image.open(image_path)
17
+ image_data = np.array(image)
18
+
19
+ question = '<image>\nPlease describe the <region> in the image in detail.'
20
+
21
+ mask = np.load('demo/masks/demo0.npy')
22
+ masks = []
23
+ masks.append(mask)
24
+ masks = np.array(masks)
25
+ masks = torch.from_numpy(masks).to(torch.uint8)
26
+
27
+ mask_ids = [0]*len(masks)
28
+
29
+ output = get_model_output(
30
+ [image_data],
31
+ question,
32
+ model=model,
33
+ tokenizer=tokenizer,
34
+ masks=masks,
35
+ mask_ids=mask_ids,
36
+ modal='image',
37
+ image_downsampling=1,
38
+ )
39
+ print(output)
40
+
41
+ def infer_video(model, tokenizer):
42
+ video_path = 'demo/videos/1.mp4'
43
+ question = '<video>\nPlease describe the <region> in the video in detail.'
44
+
45
+ frame_idx = 0 # mask from the first frame
46
+ video_tensor = load_video(video_path, fps=1, max_frames=768, frame_ids=[frame_idx])
47
+
48
+ mask = np.load('demo/masks/demo1.npy')
49
+ masks = []
50
+ masks.append(mask)
51
+ masks = np.array(masks)
52
+ masks = torch.from_numpy(masks).to(torch.uint8)
53
+
54
+ mask_ids = [0]*len(masks)
55
+
56
+ output = get_model_output(
57
+ video_tensor,
58
+ question,
59
+ model=model,
60
+ tokenizer=tokenizer,
61
+ masks=masks,
62
+ mask_ids=mask_ids,
63
+ modal='video',
64
+ )
65
+ print(output)
66
+
67
+ def main():
68
+ disable_torch_init()
69
+
70
+ # fill in the model path here
71
+ model_path = '/mnt/workspace/workgroup/yuanyq/code/videollama3/ProjectX_region/work_dirs/VideoRefer-VideoLLaMA3-7B'
72
+ model, processor, tokenizer = model_init(model_path)
73
+
74
+ # image
75
+ infer_image(model, tokenizer)
76
+
77
+ # video
78
+ infer_video(model, tokenizer)
79
+
80
+
81
+ if __name__=='__main__':
82
+ main()
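infer.py expects a binary (H, W) mask saved as .npy. A hedged sketch of how such a mask could be prepared from a COCO-style polygon annotation using `annToMask` from videollama3.mm_utils (shown below); the polygon, frame size, and output filename are placeholders.

```python
# Hypothetical preparation step: build a (H, W) uint8 mask like the demo .npy files
# from a COCO-style polygon annotation, via annToMask defined in mm_utils.py below.
import numpy as np
from videollama3.mm_utils import annToMask

h, w = 480, 640                                                        # placeholder frame size
polygon = [[150.0, 100.0, 300.0, 100.0, 300.0, 200.0, 150.0, 200.0]]   # one rectangular region
mask = annToMask(polygon, h, w).astype(np.uint8)                       # -> (480, 640) binary mask
np.save("my_region_mask.npy", mask)                                    # later: np.load("my_region_mask.npy")
```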
videollama3/mm_utils.py ADDED
@@ -0,0 +1,704 @@
1
+ import ast
2
+ import os
3
+ import re
4
+ import math
5
+ import base64
6
+ import traceback
7
+ from io import BytesIO
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torchvision.transforms.functional as VF
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from transformers import StoppingCriteria
15
+
16
+ import cv2
17
+ import imageio
18
+ import ffmpeg
19
+ from PIL import Image
20
+ from decord import VideoReader, cpu
21
+
22
+ from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
23
+ from pycocotools import mask as maskUtils
24
+
25
+ def resize_image_mask(images, masks, mask_ids, patch_size=14):
26
+ resize_images = []
27
+ resize_masks = []
28
+ mask_nums = []
29
+ for i, mask in enumerate(masks):
30
+ image = images[mask_ids[i]]
31
+ h, w = image.shape[:2]
32
+ if mask.sum()==0:
33
+ print('mask is none...')
34
+ mask = torch.ones((h, w))
35
+ rows, cols = np.where(mask == 1)
36
+
37
+ min_row, max_row = rows.min(), rows.max()
38
+ min_col, max_col = cols.min(), cols.max()
39
+
40
+ bbox = (max(0,min_row-patch_size*2), max(0,min_col-patch_size*2), min(h-1, max_row+patch_size*2), min(w-1, max_col+patch_size*2))
41
+ mask_h = bbox[2] - bbox[0]
42
+ mask_w = bbox[3] - bbox[1]
43
+ cropping_img = image[bbox[0]: bbox[2], bbox[1]: bbox[3], :]
44
+ cropping_mask = mask[bbox[0]: bbox[2], bbox[1]: bbox[3]]
45
+
46
+ scale_rate = math.ceil(math.sqrt(1960/mask.sum()))
47
+ if scale_rate==1:
48
+ if (mask.sum()/196)>100:
49
+ scale_rate = math.sqrt((mask.sum()/196)/100)
50
+ scale_rate = 1/scale_rate
51
+ resize_h = math.ceil((mask_h*scale_rate)/patch_size) * patch_size
52
+ resize_w = math.ceil((mask_w*scale_rate)/patch_size) * patch_size
53
+
54
+ resize_img = cv2.resize(cropping_img, (resize_w, resize_h))
55
+ resize_mask = F.interpolate(cropping_mask[None, None], size=(resize_h//patch_size, resize_w//patch_size), mode='bilinear', align_corners=False)[0,0]
56
+ mask_nums.append(min(10, int(resize_mask.sum())))
57
+
58
+ resize_images.append(resize_img)
59
+ resize_masks.append(resize_mask)
60
+
61
+ return resize_images, resize_masks, mask_nums
62
+
63
+ def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
64
+ start_idx=0
65
+ reshaped_features = []
66
+ for thw_group in grid_thws:
67
+ for tensor_thw in thw_group:
68
+ _, H, W = tensor_thw.squeeze().tolist()
69
+ num_elements = H * W
70
+
71
+ split_tensor = mm_features_raw[start_idx:start_idx + num_elements].view(H, W, -1)
72
+ reshaped_features.append(split_tensor)
73
+
74
+ start_idx += num_elements
75
+ assert len(mm_features_raw)==start_idx
76
+ return reshaped_features
77
+
78
+ def annToMask(mask_ann, h=None, w=None):
79
+ if isinstance(mask_ann, list):
80
+ rles = maskUtils.frPyObjects(mask_ann, h, w)
81
+ rle = maskUtils.merge(rles)
82
+ elif isinstance(mask_ann['counts'], list):
83
+ # uncompressed RLE
84
+ rle = maskUtils.frPyObjects(mask_ann, h, w)
85
+ else:
86
+ # rle
87
+ rle = mask_ann
88
+ mask = maskUtils.decode(rle)
89
+ return mask
90
+
91
+ def chunk_list(input_list, chunk_size):
92
+ return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
93
+
94
+
95
+ def load_image_from_base64(image):
96
+ return Image.open(BytesIO(base64.b64decode(image)))
97
+
98
+
99
+ def expand2square(pil_img, background_color):
100
+ width, height = pil_img.size
101
+ if width == height:
102
+ return pil_img
103
+ elif width > height:
104
+ result = Image.new(pil_img.mode, (width, width), background_color)
105
+ result.paste(pil_img, (0, (width - height) // 2))
106
+ return result
107
+ else:
108
+ result = Image.new(pil_img.mode, (height, height), background_color)
109
+ result.paste(pil_img, ((height - width) // 2, 0))
110
+ return result
111
+
112
+
113
+ def grid_divide(image, cell_size):
114
+ """
115
+ Divides an image into grid of a specified size.
116
+
117
+ Args:
118
+ image (PIL.Image.Image): The input image.
119
+ cell_size (int): The size of each cell.
120
+
121
+ Returns:
122
+ list: A list of PIL.Image.Image objects representing the patches.
123
+ """
124
+ grid = []
125
+ width, height = image.size
126
+ for i in range(0, height, cell_size):
127
+ row = []
128
+ for j in range(0, width, cell_size):
129
+ box = (j, i, j + cell_size, i + cell_size)
130
+ row.append(image.crop(box))
131
+ grid.append(row)
132
+
133
+ return grid
134
+
135
+
136
+ def load_images(image_path):
137
+ if isinstance(image_path, str) and os.path.isfile(image_path):
138
+ images = [cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)]
139
+ # images = [Image.open(image_path).convert('RGB')]
140
+ elif isinstance(image_path, str) and os.path.isdir(image_path):
141
+ images = [cv2.cvtColor(cv2.imread(os.path.join(image_path, f)), cv2.COLOR_BGR2RGB) for f in sorted(os.listdir(image_path))]
142
+ # images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
143
+ elif isinstance(image_path, list) and isinstance(image_path[0], str):
144
+ images = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in image_path]
145
+ # images = [Image.open(f).convert('RGB') for f in image_path]
146
+ elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
147
+ images = image_path
148
+ elif isinstance(image_path, Image.Image):
149
+ images = [image_path]
150
+ else:
151
+ print('image_path: ', image_path)
152
+ raise ValueError(f"Unsupported image path type: {image_path}")
153
+
154
+ return images
155
+
156
+
157
+ def process_pad_image(image, padding_value=(0, 0, 0)):
158
+ image = expand2square(image, padding_value)
159
+
160
+ return [image]
161
+
162
+
163
+ def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
164
+ best_ratio_diff = float('inf')
165
+ best_ratio = (1, 1)
166
+ area = ori_size[0] * ori_size[1]
167
+ for ratio in tgt_ratios:
168
+ tgt_ratio = ratio[0] / ratio[1]
169
+ ratio_diff = abs(src_ratio - tgt_ratio)
170
+ if ratio_diff < best_ratio_diff:
171
+ best_ratio_diff = ratio_diff
172
+ best_ratio = ratio
173
+ elif ratio_diff == best_ratio_diff:
174
+ if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
175
+ best_ratio = ratio
176
+
177
+ return best_ratio
178
+
179
+
180
+ def process_dynamic_image(image, image_size=384, use_thumbnail=True):
181
+ # Grid Params:
182
+ min_num = 1
183
+ max_num = 12
184
+
185
+ if isinstance(image_size, int):
186
+ image_size = (image_size, image_size)
187
+
188
+ ori_size = image.size
189
+ aspect_ratio = ori_size[0] / ori_size[1]
190
+
191
+ # calculate the existing image aspect ratio
192
+ tgt_ratios = []
193
+ for n in range(min_num, max_num + 1):
194
+ tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
195
+ tgt_ratios = set(tgt_ratios)
196
+ tgt_ratios = sorted(tgt_ratios, key=lambda x: x[0] * x[1])
197
+
198
+ # find the closest aspect ratio to the target
199
+ tgt_ratio = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
200
+
201
+ # resize the image to the target size
202
+ tgt_width = image_size[0] * tgt_ratio[0]
203
+ tgt_height = image_size[1] * tgt_ratio[1]
204
+ resized_img = image.resize((tgt_width, tgt_height))
205
+
206
+ # NOTE: internvl2 style split the image into one column grids
207
+ # num_grids = tgt_ratio[0] * tgt_ratio[1]
208
+ # grid_images = []
209
+ # for i in range(num_grids):
210
+ # box = (
211
+ # (i % tgt_ratio[0]) * image_size[0],
212
+ # (i // tgt_ratio[0]) * image_size[1],
213
+ # (i % tgt_ratio[0] + 1) * image_size[0],
214
+ # (i // tgt_ratio[0] + 1) * image_size[1],
215
+ # )
216
+ # # crop out the grid image
217
+ # grid_images.append(resized_img.crop(box))
218
+ # assert len(grid_images) == num_grids
219
+ # grid_images = [grid_images]
220
+
221
+ # NOTE: eager implementation
222
+ # num_grids = tgt_ratio[0] * tgt_ratio[1]
223
+ # sub_grid_images = []
224
+ # tmp_grid_images = []
225
+ # for i in range(num_grids):
226
+ # box = (
227
+ # (i % tgt_ratio[0]) * image_size[0],
228
+ # (i // tgt_ratio[0]) * image_size[1],
229
+ # (i % tgt_ratio[0] + 1) * image_size[0],
230
+ # (i // tgt_ratio[0] + 1) * image_size[1],
231
+ # )
232
+ # tmp_grid_images.append(resized_img.crop(box))
233
+
234
+ # if (i + 1) % tgt_ratio[0] == 0:
235
+ # sub_grid_images.append(tmp_grid_images)
236
+ # tmp_grid_images = []
237
+
238
+ image_grid = grid_divide(resized_img, image_size[0])
239
+
240
+ if use_thumbnail:
241
+ thumbnail_img = image.resize((image_size[0], image_size[1]))
242
+ image_grid = [[thumbnail_img]] + image_grid
243
+
244
+ return image_grid
245
+
246
+
247
+ def process_highres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
248
+ # Grid Params:
249
+ grid_width = [1, 2, 3]
250
+ grid_width_real = [x * image_size for x in grid_width]
251
+
252
+ longest_side = max(image.size)
253
+ fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
254
+ if len(fit_grid_width_real) == 0:
255
+ select_size = max(grid_width_real)
256
+ else:
257
+ select_size = min(fit_grid_width_real)
258
+
259
+ image_padded = expand2square(image, padding_value)
260
+ image_padded = image_padded.resize((select_size, select_size))
261
+ image_grid = grid_divide(image_padded, image_size)
262
+
263
+ if use_thumbnail:
264
+ thumbnail_img = image.resize((image_size, image_size))
265
+ image_grid = [[thumbnail_img]] + image_grid
266
+
267
+ return image_grid
268
+
269
+
270
+ def select_best_resolution(original_size, possible_resolutions):
271
+ """
272
+ Selects the best resolution from a list of possible resolutions based on the original size.
273
+
274
+ Args:
275
+ original_size (tuple): The original size of the image in the format (width, height).
276
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
277
+
278
+ Returns:
279
+ tuple: The best fit resolution in the format (width, height).
280
+ """
281
+ original_width, original_height = original_size
282
+ best_fit = None
283
+ max_effective_resolution = 0
284
+ min_wasted_resolution = float('inf')
285
+
286
+ for width, height in possible_resolutions:
287
+ scale = min(width / original_width, height / original_height)
288
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
289
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
290
+ wasted_resolution = (width * height) - effective_resolution
291
+
292
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
293
+ max_effective_resolution = effective_resolution
294
+ min_wasted_resolution = wasted_resolution
295
+ best_fit = (width, height)
296
+
297
+ return best_fit
298
+
299
+
300
+ def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
301
+ """
302
+ Process an image with variable resolutions.
303
+
304
+ Args:
305
+ image (PIL.Image.Image): The input image to be processed.
306
+ processor: The image processor object.
307
+
308
+ Returns:
309
+ torch.Tensor: A tensor containing the processed image patches.
310
+ """
311
+ # Grid Params:
312
+ possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
313
+ possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
314
+
315
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
316
+
317
+ # resize and padding image
318
+ nw, nh = best_resolution
319
+ ow, oh = image.size
320
+
321
+ scale_factor = min(nw / ow, nh / oh)
322
+ new_size = (int(ow * scale_factor), int(oh * scale_factor))
323
+
324
+ image_padded = Image.new("RGB", (nw, nh), padding_value)
325
+ image_padded.paste(image.resize(new_size), ((nw - new_size[0]) // 2, (nh - new_size[1]) // 2))
326
+
327
+ image_grid = grid_divide(image_padded, image_size)
328
+
329
+ if use_thumbnail:
330
+ thumbnail_img = image.resize((image_size, image_size))
331
+ image_grid = [[thumbnail_img]] + image_grid
332
+
333
+ return image_grid
334
+
335
+
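A minimal usage sketch (assuming this repo is importable as `videollama3`; the image is a dummy) showing the grid/thumbnail layout produced by `process_anyres_image`:

```python
from PIL import Image

from videollama3.mm_utils import process_anyres_image

image = Image.new("RGB", (1000, 600), (127, 127, 127))  # dummy 1000x600 image
grid = process_anyres_image(image, image_size=384, use_thumbnail=True)

# The best canvas for 1000x600 is 768x768 (a 2x2 tile grid); the thumbnail row is prepended,
# so the grid should look like [[thumb], [tile, tile], [tile, tile]].
print([len(row) for row in grid])  # expected: [1, 2, 2]
```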
336
+ def process_adares_image(image, image_size=384, use_thumbnail=True):
337
+ # Grid Params:
338
+ min_num = 1
339
+ max_num = 12
340
+
341
+ if isinstance(image_size, int):
342
+ image_size = (image_size, image_size)
343
+
344
+ ori_size = image.size
345
+ aspect_ratio = ori_size[0] / ori_size[1]
346
+
347
+ # calculate the existing image aspect ratio
348
+ tgt_ratios = []
349
+ for n in range(min_num, max_num + 1):
350
+ tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
351
+ tgt_ratios = set(tgt_ratios)
352
+ possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]
353
+
354
+ # find the most possible resolution
355
+ best_resolution = select_best_resolution(ori_size, possible_resolutions)
356
+
357
+ # resize the image to the target size
358
+ resized_img = image.resize((best_resolution[0], best_resolution[1]))
359
+
360
+ image_grid = grid_divide(resized_img, image_size[0])
361
+
362
+ if use_thumbnail:
363
+ thumbnail_img = image.resize((image_size[0], image_size[1]))
364
+ image_grid = [[thumbnail_img]] + image_grid
365
+
366
+ return image_grid
367
+
368
+
369
+ def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
370
+ images = load_images(image_path)
371
+
372
+ padding_value = tuple(int(x*255) for x in processor.image_mean)
373
+
374
+ image_grids = []
375
+ for image in images:
376
+ if aspect_ratio == 'pad':
377
+ image_grid = process_pad_image(image, padding_value=padding_value)
378
+ elif aspect_ratio == 'dynamic':
379
+ image_grid = process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
380
+ elif aspect_ratio == 'highres':
381
+ image_grid = process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
382
+ elif aspect_ratio == 'anyres':
383
+ image_grid = process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
384
+ elif aspect_ratio == 'adares':
385
+ image_grid = process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
386
+ else:
387
+ image_grid = [image]
388
+
389
+ image_grid = [processor.preprocess(image_row, return_tensors='pt', num_images=len(images)) for image_row in image_grid]
390
+ image_grids.append(image_grid)
391
+
392
+ return image_grids
393
+
394
+
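A usage sketch for `process_images` (the path and processor choice are illustrative; `load_images`, defined earlier in this file, resolves paths to PIL images):

```python
from videollama3.mm_utils import process_images
from videollama3.model.damovl_encoder import DAMOVLImageProcessor

processor = DAMOVLImageProcessor()          # padding color is derived from processor.image_mean
image_grids = process_images(
    "path/to/image.jpg",                    # illustrative path, resolved by load_images
    processor,
    aspect_ratio="anyres",                  # or 'pad' / 'dynamic' / 'highres' / 'adares'
    image_size=384,
)
# image_grids[0] is a list of rows; each row is a BatchFeature holding the flattened
# "pixel_values" and the "image_grid_thw" returned by processor.preprocess(...).
```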
395
+ def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
396
+ if mode == 'uniform':
397
+ assert num_frames is not None, "Number of frames must be provided for uniform sampling."
398
+ if duration <= num_frames:
399
+ return np.arange(duration).astype(int)
400
+ # NOTE: v1 version
401
+ # Calculate the size of each segment from which a frame will be extracted
402
+ # if duration <= num_frames:
403
+ # return np.arange(duration).astype(int)
404
+ # seg_size = float(duration - 1) / num_frames
405
+
406
+ # frame_ids = []
407
+ # for i in range(num_frames):
408
+ # # Calculate the start and end indices of each segment
409
+ # start = seg_size * i
410
+ # end = seg_size * (i + 1)
411
+ # # Append the middle index of the segment to the list
412
+ # frame_ids.append((start + end) / 2)
413
+
414
+ # return np.round(np.array(frame_ids) + 1e-6).astype(int)
415
+ # NOTE: v0 version
416
+ return np.linspace(0, duration-1, num_frames, dtype=int)
417
+ elif mode == 'fps':
418
+ assert vid_fps is not None, "FPS must be provided for FPS sampling."
419
+ fps = fps if fps is not None else NUM_FRAMES_PER_SECOND
420
+ segment_len = min(vid_fps // fps, duration)
421
+ return np.arange(segment_len // 2, duration, segment_len, dtype=int)
422
+ else:
423
+ raise ValueError(f'Unsupported frame sampling mode: {mode}')
424
+
425
+
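A standalone sketch (not part of the commit) of the two sampling modes: uniform sampling spreads a fixed frame budget evenly over the clip, while fps sampling picks one frame per segment of source frames.

```python
import numpy as np

# 'uniform' mode: 8 indices evenly spread over a 300-frame segment.
duration = 300
print(np.linspace(0, duration - 1, 8, dtype=int))
# -> [  0  42  85 128 170 213 256 299]

# 'fps' mode: a 30 fps source sampled at 1 fps takes every 30th frame,
# starting from the middle of the first segment.
vid_fps, fps = 30, 1
segment_len = min(vid_fps // fps, duration)          # 30
print(np.arange(segment_len // 2, duration, segment_len, dtype=int))
# -> [ 15  45  75 105 135 165 195 225 255 285]
```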
426
+ def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, frame_ids=None):
427
+ if s is not None and e is not None:
428
+ s = s if s >= 0. else 0.
429
+ e = e if e >= 0. else 0.
430
+ if s > e:
431
+ s, e = e, s
432
+ elif s == e:
433
+ e = s + 1
434
+
435
+ # 1. Loading Video
436
+ if os.path.isdir(video_path):
437
+ frame_files = sorted(os.listdir(video_path))
438
+
439
+ vid_fps = 3
440
+ num_frames_of_video = len(frame_files)
441
+ elif video_path.endswith('.gif'):
442
+ gif_reader = imageio.get_reader(video_path)
443
+
444
+ vid_fps = 25
445
+ num_frames_of_video = len(gif_reader)
446
+ else:
447
+ vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
448
+ # vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
449
+
450
+ vid_fps = vreader.get_avg_fps()
451
+ num_frames_of_video = len(vreader)
452
+
453
+ # 2. Determine frame range & Calculate frame indices
454
+ f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
455
+ f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
456
+ frame_indices = list(range(f_start, f_end + 1))
457
+
458
+ duration = len(frame_indices)
459
+ # 3. Sampling frame indices
460
+ max_frames = max_frames if max_frames is not None else MAX_FRAMES
461
+ if fps is not None and duration / vid_fps < max_frames:
462
+ try:
463
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
464
+ except Exception as e:
465
+ print('fps-based frame sampling failed, falling back to uniform sampling:', e)
466
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
467
+
468
+ else:
469
+ sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
470
+
471
+ # 4. Acquire frame data
472
+ if os.path.isdir(video_path):
473
+ frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
474
+ elif video_path.endswith('.gif'):
475
+ frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
476
+ else:
477
+ frames = vreader.get_batch(sampled_frame_indices).asnumpy()
478
+
479
+ # frames = frames.transpose(0, 3, 1, 2)
480
+ timestamps = [x / vid_fps for x in sampled_frame_indices]
481
+
482
+ if temporal_factor > 1:
483
+ pad_length = temporal_factor - len(frames) % temporal_factor
484
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
485
+ timestamps.extend([timestamps[-1] + (i + 1) / (fps if fps is not None else vid_fps) for i in range(pad_length)])
486
+
487
+ # NOTE: pad the video with black frames
488
+ # while num_frames is not None and len(video_data) < num_frames:
489
+ # video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
490
+
491
+ additional_frames = []
492
+ if frame_ids is not None:
493
+ if os.path.isdir(video_path):
494
+ additional_frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in frame_ids]
495
+ elif video_path.endswith('.gif'):
496
+ additional_frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in frame_ids]
497
+ else:
498
+ additional_frames = vreader.get_batch(frame_ids).asnumpy()
499
+
500
+ return frames, timestamps, additional_frames
501
+
502
+
503
+ def load_video(
504
+ video_path: str,
505
+ start_time: Optional[float] = None,
506
+ end_time: Optional[float] = None,
507
+ fps: Optional[float] = None,
508
+ max_frames: Optional[float] = None,
509
+ size: Optional[int] = None,
510
+ size_divisible: int = 1,
511
+ precise_time: bool = False,
512
+ verbose: bool = False,
513
+ temporal_factor: int = 1,
514
+ frame_ids = None
515
+ ):
516
+ """
517
+ Load and process a video file and return the frames and the timestamps of each frame.
518
+
519
+ Args:
520
+ video_path (str): Path to the video file.
521
+ start_time (float, optional): Start time in seconds. Defaults to None.
522
+ end_time (float, optional): End time in seconds. Defaults to None.
523
+ fps (float, optional): Frames per second. Defaults to None.
524
+ max_frames (int, optional): Maximum number of frames to keep. Defaults to None.
525
+ size (int, optional): Size of the shortest side. Defaults to None.
526
+ size_divisible (int, optional): Size divisible by this number. Defaults to 1.
527
+ precise_time (bool, optional): Whether to use precise time. Defaults to False.
528
+ verbose (bool, optional): Print ffmpeg output. Defaults to False.
529
+
530
+ Returns:
531
+ frames (List[PIL.Image]): List of frames.
532
+ timestamps (List[float]): List of timestamps.
533
+ """
534
+ if start_time is not None and end_time is not None and end_time - start_time < 1:
535
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
536
+ if os.path.isdir(video_path):
537
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
538
+ if video_path.endswith('.gif'):
539
+ return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
540
+ probe = ffmpeg.probe(video_path)
541
+ duration = float(probe['format']['duration'])
542
+ video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
543
+ w, h = int(video_stream['width']), int(video_stream['height'])
544
+
545
+ kwargs, input_kwargs, output_kwargs = {}, {}, {}
546
+ do_trim = start_time is not None or end_time is not None
547
+ if start_time is not None:
548
+ new_start_time = max(float(video_stream['start_time']), start_time)
549
+ duration -= new_start_time - start_time
550
+ start_time = new_start_time
551
+ else:
552
+ start_time = float(video_stream['start_time'])
553
+ if end_time is not None:
554
+ duration = min(duration, end_time - start_time)
557
+ if do_trim:
558
+ kwargs = {'ss': start_time, 't': duration}
559
+ if precise_time:
560
+ output_kwargs.update(kwargs)
561
+ else:
562
+ input_kwargs.update(kwargs)
563
+
564
+ if size is not None:
565
+ scale_factor = size / min(w, h)
566
+ new_w, new_h = round(w * scale_factor), round(h * scale_factor)
567
+ else:
568
+ new_w, new_h = w, h
569
+ new_w = new_w // size_divisible * size_divisible
570
+ new_h = new_h // size_divisible * size_divisible
571
+
572
+ # NOTE: It may result in unexpected number of frames in ffmpeg
573
+ # if calculate the fps directly according to max_frames
574
+ # NOTE: the below lines may hurt the performance
575
+ # if max_frames is not None and (fps is None or duration * fps > 2 * max_frames):
576
+ # fps = max_frames / duration * 2
577
+
578
+ stream = ffmpeg.input(video_path, **input_kwargs)
579
+ if fps is not None:
580
+ stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
581
+ if new_w != w or new_h != h:
582
+ stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
583
+ stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
584
+ out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
585
+
586
+ frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
587
+
588
+ if fps is not None:
589
+ timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
590
+ else:
591
+ timestamps = np.linspace(start_time, start_time + duration, len(frames))
592
+
593
+ max_frames = max_frames if max_frames is not None else MAX_FRAMES
594
+ if max_frames is not None and len(frames) > max_frames:
595
+ indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
596
+ frames = frames[indices]
597
+ timestamps = [timestamps[i] for i in indices]
598
+
599
+ if temporal_factor > 1:
600
+ pad_length = temporal_factor - len(frames) % temporal_factor
601
+ frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
602
+ frame_interval = 1 / fps if fps is not None else (timestamps[-1] - timestamps[-2] if len(timestamps) > 1 else 0.0)
+ timestamps = list(timestamps) + [timestamps[-1] + (i + 1) * frame_interval for i in range(pad_length)]
603
+
604
+ frames = [frame for frame in frames]
605
+ additional_frames = []
606
+ # print('frame_ids', frame_ids)
607
+ if frame_ids is not None:
608
+ vr = VideoReader(video_path, ctx=cpu(0))
609
+ additional_frames = vr.get_batch(frame_ids).asnumpy()
610
+
611
+ return frames, timestamps, additional_frames
612
+
613
+
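A usage sketch for `load_video` (illustrative path; assumes ffmpeg and the `ffmpeg-python`/`decord` bindings are installed):

```python
from videollama3.mm_utils import load_video

frames, timestamps, _ = load_video(
    "path/to/video.mp4",   # illustrative path
    fps=1,                 # decode roughly one frame per second via the ffmpeg fps filter
    max_frames=64,         # then cap the count by uniform subsampling
)
print(len(frames), frames[0].shape)  # N frames, each a (3, H, W) uint8 RGB array
print(timestamps[:3])                # per-frame timestamps in seconds
```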
614
+ def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
615
+ fps = 1 if num_frames is None else None
616
+ # FFmpeg
617
+ frames, timestamps, _ = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
618
+ # Decord
619
+ # frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)
620
+
621
+ assert len(frames) == len(timestamps), "Number of frames and timestamps must match."
622
+
623
+ if aspect_ratio == 'pad':
624
+ frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]
625
+
626
+ if aspect_ratio == 'qwen2vl':
627
+ frames = [processor.preprocess(frame, return_tensors='pt', image_num=len(frames)) for frame in frames]
628
+ grid_frames = [frames]
629
+ else:
630
+ frames = processor.preprocess(frames, return_tensors='pt', image_num=len(frames))
631
+ grid_frames = [[frames]]
632
+
633
+ return grid_frames, timestamps
634
+
635
+
636
+ def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
637
+ """Tokenize text and multimodal tag to input_ids.
638
+
639
+ Args:
640
+ prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
641
+ tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
642
+ multimodal_token (str): The multimodal tag string (e.g., '<image>' or '<video>') to be replaced by its modal token index.
643
+ """
644
+ multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
645
+ if multimodal_token_index is None:
646
+ input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
647
+ else:
648
+ prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for idx, chunk in enumerate(prompt.split(multimodal_token))]
649
+
650
+ input_ids = []
651
+ for i in range(1, 2 * len(prompt_chunks)):
652
+ if i % 2 == 1:
653
+ input_ids.extend(prompt_chunks[i // 2])
654
+ else:
655
+ input_ids.append(multimodal_token_index)
656
+
657
+ if return_tensors is not None:
658
+ if return_tensors == 'pt':
659
+ return torch.tensor(input_ids, dtype=torch.long)
660
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
661
+ return input_ids
662
+
663
+
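A standalone sketch (toy ids, not part of the commit) of how the prompt is split on the multimodal tag and the modal token index is interleaved between the text chunks:

```python
def toy_ids(text):
    return [ord(c) for c in text]            # stand-in for tokenizer(...).input_ids

VIDEO_TOKEN_INDEX = -201                     # placeholder for MODAL_INDEX_MAP['<video>']
prompt = "<video>\nDescribe the video."
chunks = [toy_ids(c) for c in prompt.split("<video>")]  # ['', '\nDescribe the video.']

input_ids = []
for i in range(1, 2 * len(chunks)):
    if i % 2 == 1:
        input_ids.extend(chunks[i // 2])     # text chunk
    else:
        input_ids.append(VIDEO_TOKEN_INDEX)  # modal placeholder where the tag used to be
print(input_ids[:4])                         # [-201, 10, 68, 101] -> tag index, '\n', 'D', 'e'
```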
664
+ def get_model_name_from_path(model_path):
665
+ model_path = model_path.strip("/")
666
+ model_paths = model_path.split("/")
667
+ if model_paths[-1].startswith('checkpoint-'):
668
+ return model_paths[-2] + "_" + model_paths[-1]
669
+ else:
670
+ return model_paths[-1]
671
+
672
+
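Two quick examples of the name mapping (paths are illustrative):

```python
from videollama3.mm_utils import get_model_name_from_path

print(get_model_name_from_path("work_dirs/videollama3_qwen2_7b/"))
# -> 'videollama3_qwen2_7b'
print(get_model_name_from_path("work_dirs/videollama3_qwen2_7b/checkpoint-2000"))
# -> 'videollama3_qwen2_7b_checkpoint-2000'
```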
673
+ class KeywordsStoppingCriteria(StoppingCriteria):
674
+ def __init__(self, keywords, tokenizer, input_ids):
675
+ self.keywords = keywords
676
+ self.keyword_ids = []
677
+ self.max_keyword_len = 0
678
+ for keyword in keywords:
679
+ cur_keyword_ids = tokenizer(keyword).input_ids
680
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
681
+ cur_keyword_ids = cur_keyword_ids[1:]
682
+ if len(cur_keyword_ids) > self.max_keyword_len:
683
+ self.max_keyword_len = len(cur_keyword_ids)
684
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
685
+ self.tokenizer = tokenizer
686
+ self.start_len = input_ids.shape[1]
687
+
688
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
689
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
690
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
691
+ for keyword_id in self.keyword_ids:
692
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
693
+ return True
694
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
695
+ for keyword in self.keywords:
696
+ if keyword in outputs:
697
+ return True
698
+ return False
699
+
700
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
701
+ outputs = []
702
+ for i in range(output_ids.shape[0]):
703
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
704
+ return all(outputs)
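A hedged usage sketch: hooking the criteria into `generate`, assuming `tokenizer`, `model`, and `input_ids` have been prepared as elsewhere in this repo and that the stop string matches the chat template in use:

```python
from transformers import StoppingCriteriaList

stop_str = "<|im_end|>"  # illustrative stop keyword
stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

output_ids = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=512,
    stopping_criteria=StoppingCriteriaList([stopping_criteria]),
)
```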
videollama3/model/__init__.py ADDED
@@ -0,0 +1,166 @@
1
+ # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
2
+ # Copyright 2023 Haotian Liu
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import warnings
19
+ import shutil
20
+
21
+ import torch
22
+ from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
23
+
24
+ from .projector import load_mm_projector
25
+ from .videollama3_qwen2 import Videollama3Qwen2ForCausalLM, Videollama3Qwen2Config
26
+
27
+
28
+ VLLMs = {
29
+ "videollama3_qwen2": Videollama3Qwen2ForCausalLM,
30
+ }
31
+
32
+ VLLMConfigs = {
33
+ "videollama3_qwen2": Videollama3Qwen2Config,
34
+ }
35
+
36
+
37
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", **kwargs):
38
+ if 'token' in kwargs:
39
+ token = kwargs['token']
40
+ else:
41
+ token = None
42
+
43
+ # NOTE: auto device_map by default
44
+ # if want to put model into a single device, you can set device_map={"": "cuda:0"}
45
+ kwargs = {"device_map": device_map, **kwargs}
46
+
47
+ config = AutoConfig.from_pretrained(model_path)
48
+ config._attn_implementation = kwargs.pop('attn_implementation', "flash_attention_2") # default to flash_attention_2
49
+
50
+ torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else kwargs.pop('torch_dtype', torch.float16)
51
+
52
+ if load_8bit:
53
+ kwargs['load_in_8bit'] = True
54
+ elif load_4bit:
55
+ # NOTE: High-version Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
56
+ # kwargs['load_in_4bit'] = True
57
+ kwargs['quantization_config'] = BitsAndBytesConfig(
58
+ load_in_4bit=True,
59
+ bnb_4bit_compute_dtype=torch_dtype,
60
+ bnb_4bit_use_double_quant=True,
61
+ bnb_4bit_quant_type='nf4'
62
+ )
63
+ else:
64
+ kwargs['torch_dtype'] = torch_dtype
65
+
66
+ # judge model type
67
+ model_type = config.model_type if hasattr(config, "model_type") else kwargs.pop('model_type', "videollama3_qwen2")
68
+
69
+ # judge pretrain/finetune
70
+ is_alignment = getattr(config, "tune_mm_mlp_adapter", False) or getattr(config, "is_alignment", False)
71
+
72
+ # NOTE: lora/qlora model loading
73
+ if 'lora' in model_name.lower() or 'qlora' in model_name.lower():
74
+ cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
75
+ # NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
76
+ # cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
77
+ model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
78
+
79
+ # NOTE: remove qlora training quantization config
80
+ if hasattr(lora_cfg_pretrained, 'quantization_config'):
81
+ del lora_cfg_pretrained.quantization_config
82
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
83
+ print('Loading VideoLLaMA from base model...')
84
+
85
+ if 'qwen2' in model_base.lower():
86
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
87
+ else:
88
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
89
+
90
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
91
+ if model.lm_head.weight.shape[0] != token_num:
92
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
93
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
94
+
95
+ print('Loading additional VideoLLaMA weights...')
96
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
97
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
98
+ else:
99
+ # this is probably from HF Hub
100
+ from huggingface_hub import hf_hub_download
101
+ def load_from_hf(repo_id, filename, subfolder=None):
102
+ cache_file = hf_hub_download(
103
+ repo_id=repo_id,
104
+ filename=filename,
105
+ subfolder=subfolder)
106
+ return torch.load(cache_file, map_location='cpu')
107
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
108
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
109
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
110
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
111
+ model.load_state_dict(non_lora_trainables, strict=False)
112
+
113
+ from peft import PeftModel
114
+ print('Loading LoRA weights...')
115
+ model = PeftModel.from_pretrained(model, model_path)
116
+ print('Merging LoRA weights...')
117
+ model = model.merge_and_unload()
118
+ print('Model is loaded...')
119
+ elif model_base is not None or '-base' in model_name.lower() or is_alignment:
120
+ # NOTE: Base/Pretrain model loading
121
+ print('Loading VideoLLaMA 3 from base model...')
122
+ cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
123
+ # NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
124
+ # cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
125
+ model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
126
+
127
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
128
+
129
+ if model_type in ['videollama3', 'videollama3_qwen2']:
130
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
131
+ else:
132
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
133
+
134
+ # NOTE; loading vision-language projector
135
+ # * old codes for loading local mm_projector.bin
136
+ # mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
137
+ # mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
138
+ # model.load_state_dict(mm_projector_weights, strict=False)
139
+ # * new codes which supports loading mm_projector.bin both offline and online
140
+ mm_projector_weights = load_mm_projector(model_path, token=token)
141
+ model.load_state_dict(mm_projector_weights, strict=False)
142
+ elif 'videollama' in model_type:
143
+ # NOTE: SFT model loading
144
+ print(model_path)
145
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
146
+
147
+ if model_type in ['videollama3_qwen2']:
148
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
149
+ else:
150
+ model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
151
+ else:
152
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token)
153
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
154
+
155
+ processor = None
156
+
157
+ if "videollama" in model_type:
158
+ vision_encoder = model.get_vision_encoder()
159
+ processor = vision_encoder.image_processor
160
+
161
+ if hasattr(model.config, "max_sequence_length"):
162
+ context_len = model.config.max_sequence_length
163
+ else:
164
+ context_len = 2048
165
+
166
+ return tokenizer, model, processor, context_len
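A usage sketch for `load_pretrained_model` (the checkpoint path is illustrative; substitute a local checkpoint or the released VideoLLaMA3 weights):

```python
from videollama3.model import load_pretrained_model
from videollama3.mm_utils import get_model_name_from_path

model_path = "path/to/videollama3_checkpoint"   # illustrative
tokenizer, model, processor, context_len = load_pretrained_model(
    model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
)
print(type(model).__name__, context_len)
```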
videollama3/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (4.04 kB). View file
 
videollama3/model/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
videollama3/model/__pycache__/processor.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
videollama3/model/__pycache__/projector.cpython-310.pyc ADDED
Binary file (5.11 kB). View file
 
videollama3/model/__pycache__/region_encoder.cpython-310.pyc ADDED
Binary file (3.43 kB). View file
 
videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc ADDED
Binary file (9.74 kB). View file
 
videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc ADDED
Binary file (4.2 kB). View file
 
videollama3/model/damovl_encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .configuration_damovl_encoder import DAMOVLVisionConfig
2
+ from .image_processing import DAMOVLImageProcessor
3
+ from .modeling_damovl_encoder import DAMOVLVisionModel
videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (407 Bytes). View file
 
videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc ADDED
Binary file (1.96 kB). View file
 
videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc ADDED
Binary file (16.7 kB). View file
 
videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
videollama3/model/damovl_encoder/configuration_damovl_encoder.py ADDED
@@ -0,0 +1,71 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2VL model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DAMOVLVisionConfig(PretrainedConfig):
28
+ model_type = "damovl"
29
+
30
+ def __init__(
31
+ self,
32
+ hidden_size=768,
33
+ intermediate_size=3072,
34
+ num_hidden_layers=12,
35
+ num_attention_heads=12,
36
+ num_channels=3,
37
+ patch_size=16,
38
+ hidden_act="gelu_pytorch_tanh",
39
+ layer_norm_eps=1e-6,
40
+ attention_dropout=0.0,
41
+ spatial_merge_size=1,
42
+ **kwargs,
43
+ ):
44
+ super().__init__(**kwargs)
45
+
46
+ self.hidden_size = hidden_size
47
+ self.intermediate_size = intermediate_size
48
+ self.num_hidden_layers = num_hidden_layers
49
+ self.num_attention_heads = num_attention_heads
50
+ self.num_channels = num_channels
51
+ self.patch_size = patch_size
52
+ self.attention_dropout = attention_dropout
53
+ self.layer_norm_eps = layer_norm_eps
54
+ self.hidden_act = hidden_act
55
+ self.spatial_merge_size = spatial_merge_size
56
+
57
+ @classmethod
58
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
59
+ cls._set_token_in_kwargs(kwargs)
60
+
61
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
62
+
63
+ # config_dict = config_dict["vision_config"]
64
+
65
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
66
+ logger.warning(
67
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
68
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
69
+ )
70
+
71
+ return cls.from_dict(config_dict, **kwargs)
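A minimal instantiation sketch (not part of the commit); note that `from_pretrained` above reads the flat config dict rather than a nested `vision_config` entry:

```python
from videollama3.model.damovl_encoder import DAMOVLVisionConfig

config = DAMOVLVisionConfig(hidden_size=768, num_hidden_layers=12, patch_size=16)
print(config.model_type, config.patch_size, config.spatial_merge_size)  # damovl 16 1
```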
videollama3/model/damovl_encoder/image_processing.py ADDED
@@ -0,0 +1,472 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Image processor class for Qwen2-VL."""
21
+
22
+ import math
23
+ from typing import Dict, List, Optional, Union
24
+
25
+ import numpy as np
26
+
27
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
28
+ from transformers.image_transforms import (
29
+ convert_to_rgb,
30
+ resize,
31
+ to_channel_dimension_format,
32
+ )
33
+ from transformers.image_utils import (
34
+ OPENAI_CLIP_MEAN,
35
+ OPENAI_CLIP_STD,
36
+ ChannelDimension,
37
+ ImageInput,
38
+ PILImageResampling,
39
+ VideoInput,
40
+ get_image_size,
41
+ infer_channel_dimension_format,
42
+ is_scaled_image,
43
+ is_valid_image,
44
+ make_list_of_images,
45
+ to_numpy_array,
46
+ valid_images,
47
+ validate_preprocess_arguments,
48
+ )
49
+ from transformers.utils import TensorType, is_vision_available, logging
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ if is_vision_available():
56
+ from PIL import Image
57
+
58
+
59
+ def make_batched_images(images) -> List[List[ImageInput]]:
60
+ """
61
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
62
+
63
+ Args:
64
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
65
+ The input image.
66
+
67
+ Returns:
68
+ list: A list of images.
69
+ """
70
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
71
+ return [img for img_list in images for img in img_list]
72
+
73
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
74
+ return images
75
+
76
+ elif is_valid_image(images):
77
+ return [images]
78
+
79
+ raise ValueError(f"Could not make batched images from {images}")
80
+
81
+
82
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
83
+ def make_batched_videos(videos) -> List[VideoInput]:
84
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
85
+ return videos
86
+
87
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
88
+ if isinstance(videos[0], Image.Image):
89
+ return [videos]
90
+ elif len(videos[0].shape) == 4:
91
+ return [list(video) for video in videos]
92
+
93
+ elif is_valid_image(videos) and len(videos.shape) == 4:
94
+ return [list(videos)]
95
+
96
+ raise ValueError(f"Could not make batched video from {videos}")
97
+
98
+
99
+ def smart_resize(
100
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
101
+ ):
102
+ """Rescales the image so that the following conditions are met:
103
+
104
+ 1. Both dimensions (height and width) are divisible by 'factor'.
105
+
106
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
107
+
108
+ 3. The aspect ratio of the image is maintained as closely as possible.
109
+
110
+ """
111
+ if height < factor or width < factor:
112
+ scale = factor / min(height, width)
113
+ width = round(scale * width)
114
+ height = round(scale * height)
115
+ elif max(height, width) / min(height, width) > 200:
116
+ raise ValueError(
117
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
118
+ )
119
+ h_bar = round(height / factor) * factor
120
+ w_bar = round(width / factor) * factor
121
+ if h_bar * w_bar > max_pixels:
122
+ beta = math.sqrt((height * width) / max_pixels)
123
+ h_bar = math.floor(height / beta / factor) * factor
124
+ w_bar = math.floor(width / beta / factor) * factor
125
+ elif h_bar * w_bar < min_pixels:
126
+ beta = math.sqrt(min_pixels / (height * width))
127
+ h_bar = math.ceil(height * beta / factor) * factor
128
+ w_bar = math.ceil(width * beta / factor) * factor
129
+ return h_bar, w_bar
130
+
131
+
132
+ class DAMOVLImageProcessor(BaseImageProcessor):
133
+ r"""
134
+ Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
135
+
136
+ Args:
137
+ do_resize (`bool`, *optional*, defaults to `True`):
138
+ Whether to resize the image's (height, width) dimensions.
139
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
140
+ Resampling filter to use when resizing the image.
141
+ do_rescale (`bool`, *optional*, defaults to `True`):
142
+ Whether to rescale the image by the specified scale `rescale_factor`.
143
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
144
+ Scale factor to use if rescaling the image.
145
+ do_normalize (`bool`, *optional*, defaults to `True`):
146
+ Whether to normalize the image.
147
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
148
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
149
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
150
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
151
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
152
+ Whether to convert the image to RGB.
153
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
154
+ The min pixels of the image to resize the image.
155
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
156
+ The max pixels of the image to resize the image.
157
+ patch_size (`int`, *optional*, defaults to 14):
158
+ The spacial patch size of the vision encoder.
159
+ temporal_patch_size (`int`, *optional*, defaults to 2):
160
+ The temporal patch size of the vision encoder.
161
+ merge_size (`int`, *optional*, defaults to 2):
162
+ The merge size of the vision encoder to llm encoder.
163
+ """
164
+
165
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
166
+
167
+ def __init__(
168
+ self,
169
+ do_resize: bool = True,
170
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
171
+ do_rescale: bool = True,
172
+ rescale_factor: Union[int, float] = 1 / 255,
173
+ do_normalize: bool = True,
174
+ image_mean: Optional[Union[float, List[float]]] = None,
175
+ image_std: Optional[Union[float, List[float]]] = None,
176
+ do_convert_rgb: bool = True,
177
+ min_pixels: int = 56 * 56,
178
+ max_pixels: int = 14 * 14 * 9477,
179
+ patch_size: int = 14,
180
+ merge_size: int = 1,
181
+ **kwargs,
182
+ ) -> None:
183
+ super().__init__(**kwargs)
184
+ self.do_resize = do_resize
185
+ self.resample = resample
186
+ self.do_rescale = do_rescale
187
+ self.rescale_factor = rescale_factor
188
+ self.do_normalize = do_normalize
189
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
190
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
191
+ self.min_pixels = min_pixels
192
+ self.max_pixels = max_pixels
193
+ self.patch_size = patch_size
194
+ self.merge_size = merge_size
195
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
196
+ self.do_convert_rgb = do_convert_rgb
197
+
198
+ self.temporal_patch_size = 1
199
+
200
+ def _preprocess(
201
+ self,
202
+ images: Union[ImageInput, VideoInput],
203
+ do_resize: bool = None,
204
+ resample: PILImageResampling = None,
205
+ do_rescale: bool = None,
206
+ rescale_factor: float = None,
207
+ do_normalize: bool = None,
208
+ image_mean: Optional[Union[float, List[float]]] = None,
209
+ image_std: Optional[Union[float, List[float]]] = None,
210
+ do_convert_rgb: bool = None,
211
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
212
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
213
+ num_images: Optional[int] = 1,
214
+ image_downsampling: Optional[int] = None,
215
+ ):
216
+ """
217
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
218
+
219
+ Args:
220
+ images (`ImageInput`):
221
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
222
+ vision_info (`List[Dict]`, *optional*):
223
+ Optional list of dictionaries containing additional information about vision inputs.
224
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
225
+ Whether to resize the image.
226
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
227
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
228
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
229
+ Whether to rescale the image.
230
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
231
+ Scale factor to use if rescaling the image.
232
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
233
+ Whether to normalize the image.
234
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
235
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
236
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
237
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
238
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
239
+ Whether to convert the image to RGB.
240
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
241
+ The channel dimension format for the output image. Can be one of:
242
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
243
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
244
+ - Unset: Use the channel dimension format of the input image.
245
+ input_data_format (`ChannelDimension` or `str`, *optional*):
246
+ The channel dimension format for the input image. Can be one of:
247
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
248
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
249
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
250
+ """
251
+ images = make_list_of_images(images)
252
+
253
+ if do_convert_rgb:
254
+ images = [convert_to_rgb(image) for image in images]
255
+
256
+ # All transformations expect numpy arrays.
257
+ images = [to_numpy_array(image) for image in images]
258
+
259
+ if is_scaled_image(images[0]) and do_rescale:
260
+ logger.warning_once(
261
+ "It looks like you are trying to rescale already rescaled images. If the input"
262
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
263
+ )
264
+ if input_data_format is None:
265
+ # We assume that all images have the same channel dimension format.
266
+ input_data_format = infer_channel_dimension_format(images[0])
267
+
268
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
269
+ resized_height, resized_width = height, width
270
+ processed_images = []
271
+ for image in images:
272
+ if do_resize:
273
+ max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
274
+ resized_height, resized_width = smart_resize(
275
+ height,
276
+ width,
277
+ factor=self.patch_size * image_downsampling,
278
+ min_pixels=self.min_pixels,
279
+ max_pixels=int(max_pixels // num_images),
280
+ )
281
+ image = resize(
282
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
283
+ )
284
+
285
+ if do_rescale:
286
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
287
+
288
+ if do_normalize:
289
+ image = self.normalize(
290
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
291
+ )
292
+
293
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
294
+ processed_images.append(image)
295
+
296
+ patches = np.array(processed_images)
297
+ if data_format == ChannelDimension.LAST:
298
+ patches = patches.transpose(0, 3, 1, 2)
299
+
300
+ channel = patches.shape[1]
301
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
302
+ patches = patches.reshape(
303
+ channel,
304
+ grid_h // image_downsampling,
305
+ image_downsampling,
306
+ self.patch_size,
307
+ grid_w // image_downsampling,
308
+ image_downsampling,
309
+ self.patch_size,
310
+ )
311
+ patches = patches.transpose(1, 4, 2, 5, 0, 3, 6)
312
+ flatten_patches = patches.reshape(
313
+ grid_h * grid_w, channel * self.patch_size * self.patch_size
314
+ )
315
+ # print('image_downsampling', image_downsampling)
316
+ # flatten_patches1 = flatten_patches.reshape(grid_h, grid_w, channel, -1)
317
+ # from matplotlib import pyplot as plt
318
+ # plt.imshow(flatten_patches1[:,:,:,0])
319
+ # plt.savefig('8.png')
320
+
321
+ return flatten_patches, (1, grid_h, grid_w)
322
+
323
+ def preprocess(
324
+ self,
325
+ images: ImageInput,
326
+ videos: VideoInput = None,
327
+ do_resize: bool = None,
328
+ size: Dict[str, int] = None,
329
+ resample: PILImageResampling = None,
330
+ do_rescale: bool = None,
331
+ rescale_factor: float = None,
332
+ do_normalize: bool = None,
333
+ image_mean: Optional[Union[float, List[float]]] = None,
334
+ image_std: Optional[Union[float, List[float]]] = None,
335
+ do_convert_rgb: bool = None,
336
+ return_tensors: Optional[Union[str, TensorType]] = None,
337
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
338
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
339
+ num_images: Optional[int] = 1,
340
+ image_downsampling: Optional[int] = None,
341
+ ):
342
+ """
343
+ Args:
344
+ images (`ImageInput`):
345
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
346
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
347
+ videos (`VideoInput`):
348
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
349
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
350
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
351
+ Whether to resize the image.
352
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
353
+ Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
354
+ the longest edge resized to keep the input aspect ratio.
355
+ resample (`int`, *optional*, defaults to `self.resample`):
356
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
357
+ has an effect if `do_resize` is set to `True`.
358
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
359
+ Whether to rescale the image.
360
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
361
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
362
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
363
+ Whether to normalize the image.
364
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
365
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
366
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
367
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
368
+ `True`.
369
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
370
+ Whether to convert the image to RGB.
371
+ return_tensors (`str` or `TensorType`, *optional*):
372
+ The type of tensors to return. Can be one of:
373
+ - Unset: Return a list of `np.ndarray`.
374
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
375
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
376
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
377
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
378
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
379
+ The channel dimension format for the output image. Can be one of:
380
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
381
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
382
+ - Unset: Use the channel dimension format of the input image.
383
+ input_data_format (`ChannelDimension` or `str`, *optional*):
384
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
385
+ from the input image. Can be one of:
386
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
387
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
388
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
389
+
390
+ """
391
+ do_resize = do_resize if do_resize is not None else self.do_resize
392
+ size = size if size is not None else self.size
393
+ resample = resample if resample is not None else self.resample
394
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
395
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
396
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
397
+ image_mean = image_mean if image_mean is not None else self.image_mean
398
+ image_std = image_std if image_std is not None else self.image_std
399
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
400
+ image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size
401
+
402
+ if images is not None:
403
+ images = make_batched_images(images)
404
+ if videos is not None:
405
+ videos = make_batched_videos(videos)
406
+
407
+ if images is not None and not valid_images(images):
408
+ raise ValueError(
409
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
410
+ "torch.Tensor, tf.Tensor or jax.ndarray."
411
+ )
412
+
413
+ validate_preprocess_arguments(
414
+ rescale_factor=rescale_factor,
415
+ do_normalize=do_normalize,
416
+ image_mean=image_mean,
417
+ image_std=image_std,
418
+ do_resize=do_resize,
419
+ size=size,
420
+ resample=resample,
421
+ )
422
+
423
+ if images is not None:
424
+ pixel_values, vision_grid_thws = [], []
425
+ for image in images:
426
+ patches, image_grid_thw = self._preprocess(
427
+ image,
428
+ do_resize=do_resize,
429
+ resample=resample,
430
+ do_rescale=do_rescale,
431
+ rescale_factor=rescale_factor,
432
+ do_normalize=do_normalize,
433
+ image_mean=image_mean,
434
+ image_std=image_std,
435
+ data_format=data_format,
436
+ do_convert_rgb=do_convert_rgb,
437
+ input_data_format=input_data_format,
438
+ num_images=num_images,
439
+ image_downsampling=image_downsampling,
440
+ )
441
+ pixel_values.extend(patches)
442
+ vision_grid_thws.append(image_grid_thw)
443
+ pixel_values = np.array(pixel_values)
444
+ vision_grid_thws = np.array(vision_grid_thws)
445
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
446
+
447
+ assert videos is None, "Video inputs are not supported for now."
448
+ # NOTE: not support video for now
449
+ # if videos is not None:
450
+ # pixel_values, vision_grid_thws = [], []
451
+ # for images in videos:
452
+ # patches, video_grid_thw = self._preprocess(
453
+ # images,
454
+ # do_resize=do_resize,
455
+ # resample=resample,
456
+ # do_rescale=do_rescale,
457
+ # rescale_factor=rescale_factor,
458
+ # do_normalize=do_normalize,
459
+ # image_mean=image_mean,
460
+ # image_std=image_std,
461
+ # data_format=data_format,
462
+ # do_convert_rgb=do_convert_rgb,
463
+ # input_data_format=input_data_format,
464
+ # image_num=image_num,
465
+ # )
466
+ # pixel_values.extend(patches)
467
+ # vision_grid_thws.append(video_grid_thw)
468
+ # pixel_values = np.array(pixel_values)
469
+ # vision_grid_thws = np.array(vision_grid_thws)
470
+ # data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
471
+
472
+ return BatchFeature(data=data, tensor_type=return_tensors)
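A usage sketch for the processor (dummy image, default settings): the output is a flat patch matrix plus the (t, h, w) patch grid per image.

```python
from PIL import Image

from videollama3.model.damovl_encoder import DAMOVLImageProcessor

processor = DAMOVLImageProcessor(patch_size=14, merge_size=1)
image = Image.new("RGB", (640, 480), (0, 0, 0))           # dummy image

out = processor.preprocess(image, return_tensors="pt", num_images=1)
print(out["pixel_values"].shape)   # (grid_h * grid_w, 3 * 14 * 14) flattened patches
print(out["image_grid_thw"])       # tensor([[1, grid_h, grid_w]]); t is fixed to 1 for images
```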
videollama3/model/damovl_encoder/modeling_damovl_encoder.py ADDED
@@ -0,0 +1,542 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch DAMOVL vision encoder model (adapted from the SigLIP implementation)."""
16
+
17
+ import math
18
+ import warnings
19
+ from dataclasses import dataclass
20
+ from typing import Any, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
27
+ from torch.nn.init import _calculate_fan_in_and_fan_out
28
+ import torch.nn.functional as F
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (add_start_docstrings,
33
+ add_start_docstrings_to_model_forward,
34
+ is_flash_attn_2_available,
35
+ is_flash_attn_greater_or_equal_2_10, logging,
36
+ replace_return_docstrings)
37
+ from .configuration_damovl_encoder import DAMOVLVisionConfig
38
+
39
+
40
+ if is_flash_attn_2_available():
41
+ from flash_attn import flash_attn_varlen_func
42
+ from transformers.modeling_flash_attention_utils import \
43
+ _flash_attention_forward
44
+ else:
45
+ flash_attn_varlen_func = None
46
+
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ def _trunc_normal_(tensor, mean, std, a, b):
52
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
53
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
54
+ def norm_cdf(x):
55
+ # Computes standard normal cumulative distribution function
56
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
57
+
58
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
59
+ warnings.warn(
60
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
61
+ "The distribution of values may be incorrect.",
62
+ stacklevel=2,
63
+ )
64
+
65
+ # Values are generated by using a truncated uniform distribution and
66
+ # then using the inverse CDF for the normal distribution.
67
+ # Get upper and lower cdf values
68
+ l = norm_cdf((a - mean) / std)
69
+ u = norm_cdf((b - mean) / std)
70
+
71
+ # Uniformly fill tensor with values from [l, u], then translate to
72
+ # [2l-1, 2u-1].
73
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
74
+
75
+ # Use inverse cdf transform for normal distribution to get truncated
76
+ # standard normal
77
+ tensor.erfinv_()
78
+
79
+ # Transform to proper mean, std
80
+ tensor.mul_(std * math.sqrt(2.0))
81
+ tensor.add_(mean)
82
+
83
+ # Clamp to ensure it's in the proper range
84
+ tensor.clamp_(min=a, max=b)
85
+
86
+
87
+ def trunc_normal_tf_(
88
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
89
+ ) -> torch.Tensor:
90
+ """Fills the input Tensor with values drawn from a truncated
91
+ normal distribution. The values are effectively drawn from the
92
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
93
+ with values outside :math:`[a, b]` redrawn until they are within
94
+ the bounds. The method used for generating the random values works
95
+ best when :math:`a \\leq \text{mean} \\leq b`.
96
+
97
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
98
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
99
+ and the result is subsequently scaled and shifted by the mean and std args.
100
+
101
+ Args:
102
+ tensor: an n-dimensional `torch.Tensor`
103
+ mean: the mean of the normal distribution
104
+ std: the standard deviation of the normal distribution
105
+ a: the minimum cutoff value
106
+ b: the maximum cutoff value
107
+ """
108
+ with torch.no_grad():
109
+ _trunc_normal_(tensor, 0, 1.0, a, b)
110
+ tensor.mul_(std).add_(mean)
111
+
112
+
113
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
114
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
115
+ if mode == "fan_in":
116
+ denom = fan_in
117
+ elif mode == "fan_out":
118
+ denom = fan_out
119
+ elif mode == "fan_avg":
120
+ denom = (fan_in + fan_out) / 2
121
+
122
+ variance = scale / denom
123
+
124
+ if distribution == "truncated_normal":
125
+ # constant is stddev of standard normal truncated to (-2, 2)
126
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
127
+ elif distribution == "normal":
128
+ with torch.no_grad():
129
+ tensor.normal_(std=math.sqrt(variance))
130
+ elif distribution == "uniform":
131
+ bound = math.sqrt(3 * variance)
132
+ with torch.no_grad():
133
+ tensor.uniform_(-bound, bound)
134
+ else:
135
+ raise ValueError(f"invalid distribution {distribution}")
136
+
137
+
138
+ def lecun_normal_(tensor):
139
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
140
+
141
+
142
+ def default_flax_embed_init(tensor):
143
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
144
+
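For context, a minimal sketch of how these initializers are typically invoked on a fresh weight tensor; the tensor shape and std value are illustrative assumptions, not taken from this repo:

    import torch

    w = torch.empty(64, 64)
    trunc_normal_tf_(w, std=0.02)   # sample N(0, 1) truncated to [-2, 2], then scale by std and shift by mean
    variance_scaling_(w, mode="fan_in", distribution="truncated_normal")
    lecun_normal_(w)                # equivalent to the variance_scaling_ call above

The `/ 0.87962566103423978` factor in `variance_scaling_` compensates for the reduced standard deviation of a standard normal truncated to (-2, 2).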
145
+
146
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
147
+ def rotate_half(x):
148
+ """Rotates half the hidden dims of the input."""
149
+ x1 = x[..., : x.shape[-1] // 2]
150
+ x2 = x[..., x.shape[-1] // 2 :]
151
+ return torch.cat((-x2, x1), dim=-1)
152
+
153
+
154
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
155
+ orig_dtype = tensor.dtype
156
+ tensor = tensor.float()
157
+ cos = freqs.cos()
158
+ sin = freqs.sin()
159
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
160
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
161
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
162
+ output = output.to(orig_dtype)
163
+ return output
164
+
165
+
166
+ class VisionRotaryEmbedding(nn.Module):
167
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
168
+ super().__init__()
169
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
170
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
171
+
172
+ def forward(self, seqlen: int) -> torch.Tensor:
173
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
174
+ freqs = torch.outer(seq, self.inv_freq)
175
+ return freqs
176
+
177
+
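A rough shape sketch of the rotary pieces above, assuming a 64-dim attention head, 8 heads, and 16 patch tokens (all hypothetical); it only illustrates how the `VisionRotaryEmbedding` output is expanded and consumed by `apply_rotary_pos_emb_vision`:

    import torch

    head_dim = 64
    rope = VisionRotaryEmbedding(head_dim // 2)   # as instantiated in DAMOVLVisionEncoder below
    freqs = rope(16)                              # [16, head_dim // 4]
    # rot_pos_emb() gathers one row per (h, w) position and flattens to [tokens, head_dim // 2];
    # concatenating freqs with itself mimics that layout for this toy case.
    pos_emb = torch.cat([freqs, freqs], dim=-1)   # [16, head_dim // 2]
    q = torch.randn(16, 8, head_dim)              # [tokens, num_heads, head_dim]
    q_rot = apply_rotary_pos_emb_vision(q.unsqueeze(0), pos_emb).squeeze(0)  # same shape as q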
178
+ class DAMOVLVisionEmbeddings(nn.Module):
179
+ def __init__(self, config: DAMOVLVisionConfig):
180
+ super().__init__()
181
+ self.config = config
182
+ self.embed_dim = config.hidden_size
183
+ self.patch_size = config.patch_size
184
+
185
+ self.patch_embedding = nn.Conv2d(
186
+ in_channels=config.num_channels,
187
+ out_channels=self.embed_dim,
188
+ kernel_size=self.patch_size,
189
+ stride=self.patch_size,
190
+ padding="valid",
191
+ )
192
+
193
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
194
+ hidden_states = hidden_states.view(
195
+ -1, self.config.num_channels, self.patch_size, self.patch_size
196
+ )
197
+ patch_embeds = self.patch_embedding(hidden_states) # shape = [*, width, grid, grid]
198
+ # embeddings = patch_embeds.flatten(2).transpose(1, 2)
199
+ embeddings = patch_embeds.view(-1, self.embed_dim)
200
+
201
+ return embeddings
202
+
203
+
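A small shape check for the embedding module above; the config stub carries only the three fields the module reads, and the sizes (hidden_size=1152, patch_size=14, 3 channels) are assumptions for illustration:

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=1152, patch_size=14, num_channels=3)
    emb = DAMOVLVisionEmbeddings(cfg)
    flat_patches = torch.randn(5, 3 * 14 * 14)   # 5 flattened 14x14 RGB patches
    print(emb(flat_patches).shape)               # torch.Size([5, 1152]) -> one vector per patch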
204
+ class VisionAttention(nn.Module):
205
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
206
+
207
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
208
+ def __init__(self, config):
209
+ super().__init__()
210
+ self.config = config
211
+ self.embed_dim = config.hidden_size
212
+ self.num_heads = config.num_attention_heads
213
+ self.head_dim = self.embed_dim // self.num_heads
214
+ if self.head_dim * self.num_heads != self.embed_dim:
215
+ raise ValueError(
216
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
217
+ f" {self.num_heads})."
218
+ )
219
+ self.scale = self.head_dim**-0.5
220
+ self.dropout = config.attention_dropout
221
+
222
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
223
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
224
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
225
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.Tensor,
230
+ cu_seqlens: torch.Tensor,
231
+ rotary_pos_emb: torch.Tensor = None,
232
+ ) -> torch.Tensor:
233
+ """Input shape: Time x Channel"""
234
+
235
+ q_len, _ = hidden_states.size()
236
+
237
+ query_states = self.q_proj(hidden_states)
238
+ key_states = self.k_proj(hidden_states)
239
+ value_states = self.v_proj(hidden_states)
240
+
241
+ query_states = query_states.view(q_len, self.num_heads, self.head_dim)
242
+ key_states = key_states.view(q_len, self.num_heads, self.head_dim)
243
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
244
+
245
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
246
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
247
+
248
+ # Additive block-diagonal mask: 0 inside each image/frame block, a large negative value elsewhere.
+ attention_mask = torch.full([1, q_len, q_len], torch.finfo(query_states.dtype).min, device=query_states.device, dtype=query_states.dtype)
249
+ for i in range(1, len(cu_seqlens)):
250
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
251
+
252
+ query_states = query_states.transpose(0, 1)
253
+ key_states = key_states.transpose(0, 1)
254
+ value_states = value_states.transpose(0, 1)
255
+
256
+ attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
257
+ attn_weights = attn_weights + attention_mask
258
+
259
+ # upcast attention to fp32
260
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
261
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
262
+ attn_output = torch.matmul(attn_weights, value_states)
263
+
264
+ attn_output = attn_output.transpose(0, 1)
265
+ attn_output = attn_output.reshape(q_len, -1)
266
+ attn_output = self.out_proj(attn_output)
267
+
268
+ return attn_output
269
+
270
+
271
+ class VisionFlashAttention2(VisionAttention):
272
+ def __init__(self, *args, **kwargs):
273
+ super().__init__(*args, **kwargs)
274
+
275
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.Tensor,
279
+ cu_seqlens: torch.Tensor,
280
+ rotary_pos_emb: torch.Tensor = None,
281
+ ) -> torch.Tensor:
282
+ q_len, _ = hidden_states.size()
283
+
284
+ query_states = self.q_proj(hidden_states)
285
+ key_states = self.k_proj(hidden_states)
286
+ value_states = self.v_proj(hidden_states)
287
+
288
+ # flash_attn_varlen_func consumes packed inputs of shape
290
+ # (total_tokens, num_heads, head_dim),
291
+ # so we keep the flattened token dimension and only split out the heads.
291
+ query_states = query_states.view(q_len, self.num_heads, self.head_dim)
292
+ key_states = key_states.view(q_len, self.num_heads, self.head_dim)
293
+ value_states = value_states.view(q_len, self.num_heads, self.head_dim)
294
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
295
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
296
+
297
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
298
+ attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
299
+ q_len, -1
300
+ )
301
+ attn_output = self.out_proj(attn_output)
302
+
303
+ return attn_output
304
+
305
+
306
+ class VisionSdpaAttention(VisionAttention):
307
+ def forward(
308
+ self,
309
+ hidden_states: torch.Tensor,
310
+ cu_seqlens: torch.Tensor,
311
+ rotary_pos_emb: torch.Tensor = None,
+ output_attentions: bool = False,
312
+ ) -> torch.Tensor:
313
+ if output_attentions:
314
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
315
+ logger.warning_once(
316
+ "DAMOVLVisionModel is using VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
317
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
318
+ )
319
+ return super().forward(
320
+ hidden_states=hidden_states,
321
+ cu_seqlens=cu_seqlens,
322
+ rotary_pos_emb=rotary_pos_emb,
323
+ )
324
+
325
+ seq_length = hidden_states.shape[0]
326
+ query_states = self.q_proj(hidden_states)
327
+ key_states = self.k_proj(hidden_states)
328
+ value_states = self.v_proj(hidden_states)
329
+
330
+ query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
332
+ key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
333
+ value_states = value_states.view(seq_length, self.num_heads, self.head_dim)
333
+
334
+ query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
335
+ key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
336
+
337
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
338
+ for i in range(1, len(cu_seqlens)):
339
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
340
+
341
+ query_states = query_states.transpose(0, 1)
342
+ key_states = key_states.transpose(0, 1)
343
+ value_states = value_states.transpose(0, 1)
344
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
345
+ attn_output = attn_output.transpose(0, 1)
346
+ attn_output = attn_output.reshape(seq_length, -1)
347
+ attn_output = self.out_proj(attn_output)
348
+ return attn_output
349
+
350
+
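All three attention variants above rely on `cu_seqlens` (cumulative token counts per image/frame) to keep attention inside each visual sample. A toy sketch of the block-diagonal mask the eager/SDPA paths build, with made-up sequence lengths:

    import torch

    cu_seqlens = torch.tensor([0, 4, 6])   # two samples packed together: 4 tokens, then 2 tokens
    q_len = int(cu_seqlens[-1])
    allowed = torch.zeros(1, q_len, q_len, dtype=torch.bool)
    for i in range(1, len(cu_seqlens)):
        allowed[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True
    # `allowed` is block-diagonal: the SDPA path passes it directly, while the eager path uses
    # the equivalent additive mask (0 inside each block, a large negative value elsewhere).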
351
+ DAMOVL_VISION_ATTENTION_CLASSES = {
352
+ "eager": VisionAttention,
353
+ "flash_attention_2": VisionFlashAttention2,
354
+ "sdpa": VisionSdpaAttention,
355
+ }
356
+
357
+
358
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->DAMOVL
359
+ class DAMOVLVisionMLP(nn.Module):
360
+ def __init__(self, config):
361
+ super().__init__()
362
+ self.config = config
363
+ self.activation_fn = ACT2FN[config.hidden_act]
364
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
365
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
366
+
367
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
368
+ hidden_states = self.fc1(hidden_states)
369
+ hidden_states = self.activation_fn(hidden_states)
370
+ hidden_states = self.fc2(hidden_states)
371
+ return hidden_states
372
+
373
+
374
+ class DAMOVLVisionEncoderLayer(nn.Module):
375
+ def __init__(self, config: DAMOVLVisionConfig):
376
+ super().__init__()
377
+ self.embed_dim = config.hidden_size
378
+ self.self_attn = DAMOVL_VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
379
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
380
+ self.mlp = DAMOVLVisionMLP(config)
381
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
382
+
383
+ # Ignore copy
384
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
385
+ hidden_states = hidden_states + self.self_attn(
386
+ self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
387
+ )
388
+ hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
389
+ return hidden_states
390
+
391
+
392
+ class DAMOVLPreTrainedModel(PreTrainedModel):
393
+ """
394
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
395
+ models.
396
+ """
397
+
398
+ config_class = DAMOVLVisionConfig
399
+ base_model_prefix = "damovl"
400
+ supports_gradient_checkpointing = True
401
+ _no_split_modules = [
402
+ "DAMOVLVisionEncoderLayer",
403
+ "DAMOVLVisionEmbeddings",
404
+ ]
405
+ _supports_flash_attn_2 = True
406
+ _supports_sdpa = True
407
+
408
+ def _init_weights(self, module):
409
+ """Initialize the weights"""
410
+ if isinstance(module, nn.Embedding):
411
+ default_flax_embed_init(module.weight)
412
+ elif isinstance(module, VisionAttention):
413
+ nn.init.xavier_uniform_(module.q_proj.weight)
414
+ nn.init.xavier_uniform_(module.k_proj.weight)
415
+ nn.init.xavier_uniform_(module.v_proj.weight)
416
+ nn.init.xavier_uniform_(module.out_proj.weight)
417
+ nn.init.zeros_(module.q_proj.bias)
418
+ nn.init.zeros_(module.k_proj.bias)
419
+ nn.init.zeros_(module.v_proj.bias)
420
+ nn.init.zeros_(module.out_proj.bias)
421
+ elif isinstance(module, DAMOVLVisionMLP):
422
+ nn.init.xavier_uniform_(module.fc1.weight)
423
+ nn.init.xavier_uniform_(module.fc2.weight)
424
+ nn.init.normal_(module.fc1.bias, std=1e-6)
425
+ nn.init.normal_(module.fc2.bias, std=1e-6)
426
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
427
+ lecun_normal_(module.weight)
428
+ if module.bias is not None:
429
+ nn.init.zeros_(module.bias)
430
+ elif isinstance(module, nn.LayerNorm):
431
+ module.bias.data.zero_()
432
+ module.weight.data.fill_(1.0)
433
+
434
+
435
+ class DAMOVLVisionEncoder(nn.Module):
436
+ def __init__(self, config: DAMOVLVisionConfig):
437
+ super().__init__()
438
+ self.config = config
439
+ head_dim = config.hidden_size // config.num_attention_heads
440
+ self.spatial_merge_size = config.spatial_merge_size
441
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
442
+ self.layers = nn.ModuleList([DAMOVLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
443
+ self.gradient_checkpointing = False
444
+
445
+ def rot_pos_emb(self, grid_thw, strides):
446
+ pos_ids = []
447
+ for (t, h, w), stride in zip(grid_thw, strides):
448
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
449
+ hpos_ids = hpos_ids.reshape(
450
+ h // stride,
451
+ stride,
452
+ w // stride,
453
+ stride,
454
+ )
455
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
456
+ hpos_ids = hpos_ids.flatten()
457
+
458
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
459
+ wpos_ids = wpos_ids.reshape(
460
+ h // stride,
461
+ stride,
462
+ w // stride,
463
+ stride,
464
+ )
465
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
466
+ wpos_ids = wpos_ids.flatten()
467
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
468
+ pos_ids = torch.cat(pos_ids, dim=0)
469
+ max_grid_size = grid_thw[:, 1:].max()
470
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
471
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
472
+ return rotary_pos_emb
473
+
474
+ def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
475
+ # BUG: the commented-out per-sample loop below causes a DeepSpeed issue: `RuntimeError: disagreement between rank0 and rankx`
476
+ # rotary_pos_emb = []
477
+ # for thw in grid_thws:
478
+ # rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
479
+ # rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
480
+ # grid_thws = torch.cat(grid_thws, dim = 0)
481
+
482
+ # New approach to building the rotary position embedding:
483
+ # grid_thws has shape [num_flattened_images, 3]
484
+ # grid_thws = torch.cat(grid_thws, dim = 0)  # already done in `encoder.py`
485
+ rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)
486
+
487
+ cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
488
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
489
+
490
+ for blk in self.layers:
491
+ if self.gradient_checkpointing and self.training:
492
+ hidden_states = self._gradient_checkpointing_func(
493
+ blk.__call__,
494
+ hidden_states,
495
+ cu_seqlens,
496
+ rotary_pos_emb
497
+ )
498
+ else:
499
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
500
+ return hidden_states
501
+
502
+
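A worked example of the `cu_seqlens` computation in the encoder forward above, using made-up grids (one 4x6 image and one 2-frame 2x2 video):

    import torch
    import torch.nn.functional as F

    grid_thws = torch.tensor([[1, 4, 6], [2, 2, 2]])   # rows are (t, h, w)
    cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
    print(cu_seqlens)   # tensor([ 0, 24, 28, 32], dtype=torch.int32): one attention block per frame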
503
+ class DAMOVLVisionTransformer(nn.Module):
504
+ def __init__(self, config: DAMOVLVisionConfig):
505
+ super().__init__()
506
+ self.config = config
507
+ embed_dim = config.hidden_size
508
+
509
+ self.embeddings = DAMOVLVisionEmbeddings(config)
510
+ self.encoder = DAMOVLVisionEncoder(config)
511
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
512
+
513
+ def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
514
+
515
+ # print(hidden_states)
516
+
517
+ # hidden_states = torch.cat(hidden_states, dim = 1)
518
+
519
+ hidden_states = self.embeddings(hidden_states)
520
+ hidden_states = self.encoder(hidden_states, grid_thws, strides)
521
+ hidden_states = self.post_layernorm(hidden_states)
522
+
523
+ return hidden_states
524
+
525
+
526
+ class DAMOVLVisionModel(DAMOVLPreTrainedModel):
527
+ config_class = DAMOVLVisionConfig
528
+ main_input_name = "hidden_states"
529
+
530
+ def __init__(self, config: DAMOVLVisionConfig):
531
+ super().__init__(config)
532
+
533
+ self.vision_model = DAMOVLVisionTransformer(config)
534
+
535
+ # Initialize weights and apply final processing
536
+ self.post_init()
537
+
538
+ def get_input_embeddings(self) -> nn.Module:
539
+ return self.vision_model.embeddings.patch_embedding
540
+
541
+ def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
542
+ return self.vision_model(hidden_states=hidden_states, grid_thws=grid_thws, strides=strides)
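A hedged end-to-end sketch of this encoder; the checkpoint path is hypothetical, and the patch layout assumes a 14x14 patch size with 3 channels (i.e., 588 values per flattened patch):

    import torch

    model = DAMOVLVisionModel.from_pretrained("./checkpoints/damovl_vision_encoder")  # hypothetical local path
    grid_thws = torch.tensor([[1, 4, 6]])    # one image split into 4x6 patches
    patches = torch.randn(24, 3 * 14 * 14)   # flattened patches, one row per patch
    feats = model(hidden_states=patches, grid_thws=grid_thws, strides=[1])
    print(feats.shape)                        # [24, hidden_size]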
videollama3/model/encoder.py ADDED
@@ -0,0 +1,385 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import (CLIPImageProcessor, CLIPVisionConfig,
6
+ CLIPVisionModel, SiglipImageProcessor,
7
+ SiglipVisionConfig, SiglipVisionModel)
8
+
9
+ from .qwen2vl_encoder import (Qwen2VisionTransformerPretrainedModel,
10
+ Qwen2VLImageProcessor, Qwen2VLVisionConfig)
11
+
12
+ from .damovl_encoder import (DAMOVLImageProcessor, DAMOVLVisionConfig, DAMOVLVisionModel)
13
+
14
+
15
+ class CLIPVisionEncoder(nn.Module):
16
+
17
+ def __init__(self, vision_encoder, args, delay_load=False):
18
+ super().__init__()
19
+
20
+ self.is_loaded = False
21
+
22
+ self.vision_encoder_name = vision_encoder
23
+ self.select_layer = args.mm_vision_select_layer
24
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
25
+
26
+ if not delay_load:
27
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
28
+ self.load_model()
29
+ else:
30
+ # flash-attention-2 may be unavailable at inference time, so fall back to SDPA.
31
+ self.attn_implementation = 'sdpa' # 'eager'
32
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
33
+
34
+ def load_model(self):
35
+ if self.is_loaded:
36
+ print('Vision tower is already loaded; skipping repeated `load_model` call.')
37
+ return
38
+
39
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
40
+
41
+ self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name,
42
+ attn_implementation=self.attn_implementation)
43
+
44
+ self.is_loaded = True
45
+
46
+ def feature_select(self, image_forward_outs):
47
+ image_features = image_forward_outs.hidden_states[self.select_layer]
48
+ if self.select_feature == 'patch':
49
+ image_features = image_features[:, 1:]
50
+ elif self.select_feature == 'cls_patch':
51
+ image_features = image_features
52
+ else:
53
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
54
+ return image_features
55
+
56
+ def forward(self, images, **kwargs):
57
+ images = torch.cat(images)
58
+ if type(images) is list:
59
+ image_features = []
60
+ for image in images:
61
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
62
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
63
+ image_features.append(image_feature)
64
+ else:
65
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
66
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
67
+
68
+ return image_features
69
+
70
+ @property
71
+ def dummy_feature(self):
72
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
73
+
74
+ @property
75
+ def dtype(self):
76
+ return self.vision_encoder.dtype
77
+
78
+ @property
79
+ def device(self):
80
+ return self.vision_encoder.device
81
+
82
+ @property
83
+ def config(self):
84
+ if self.is_loaded:
85
+ return self.vision_encoder.config
86
+ else:
87
+ return self.cfg_only
88
+
89
+ @property
90
+ def hidden_size(self):
91
+ return self.config.hidden_size
92
+
93
+ @property
94
+ def num_patches(self):
95
+ return (self.config.image_size // self.config.patch_size) ** 2
96
+
97
+ @property
98
+ def num_patches_per_side(self):
99
+ return self.config.image_size // self.config.patch_size
100
+
101
+ @property
102
+ def image_size(self):
103
+ return self.config.image_size
104
+
105
+
106
+ class SiglipVisionEncoder(nn.Module):
107
+
108
+ def __init__(self, vision_encoder, args, delay_load=False):
109
+ super().__init__()
110
+
111
+ self.is_loaded = False
112
+
113
+ self.vision_encoder_name = vision_encoder
114
+ self.select_layer = args.mm_vision_select_layer
115
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
116
+
117
+ if not delay_load:
118
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
119
+ self.load_model()
120
+ else:
121
+ # flash-attention-2 may be unavailable at inference time, so fall back to SDPA.
122
+ self.attn_implementation = 'sdpa' # 'eager'
123
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
124
+
125
+ def load_model(self):
126
+ if self.is_loaded:
127
+ print('Vision tower is already loaded; skipping repeated `load_model` call.')
128
+ return
129
+
130
+ self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_encoder_name)
131
+
132
+ self.vision_encoder = SiglipVisionModel.from_pretrained(self.vision_encoder_name,
133
+ attn_implementation=self.attn_implementation)
134
+
135
+ self.is_loaded = True
136
+
137
+ def feature_select(self, image_forward_outs):
138
+ image_features = image_forward_outs.hidden_states[self.select_layer]
139
+ if self.select_feature == 'patch':
140
+ image_features = image_features
141
+ else:
142
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
143
+ return image_features
144
+
145
+ def forward(self, images, **kwargs):
146
+ images = torch.cat(images)
147
+ if type(images) is list:
148
+ image_features = []
149
+ for image in images:
150
+ image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
151
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
152
+ image_features.append(image_feature)
153
+ else:
154
+ image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
155
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
156
+
157
+ return image_features
158
+
159
+ @property
160
+ def dummy_feature(self):
161
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
162
+
163
+ @property
164
+ def dtype(self):
165
+ return self.vision_encoder.dtype
166
+
167
+ @property
168
+ def device(self):
169
+ return self.vision_encoder.device
170
+
171
+ @property
172
+ def config(self):
173
+ if self.is_loaded:
174
+ return self.vision_encoder.config
175
+ else:
176
+ return self.cfg_only
177
+
178
+ @property
179
+ def hidden_size(self):
180
+ return self.config.hidden_size
181
+
182
+ @property
183
+ def num_patches(self):
184
+ return (self.config.image_size // self.config.patch_size) ** 2
185
+
186
+ @property
187
+ def num_patches_per_side(self):
188
+ return self.config.image_size // self.config.patch_size
189
+
190
+ @property
191
+ def image_size(self):
192
+ return self.config.image_size
193
+
194
+
195
+ class Qwen2VLVisionEncoder(nn.Module):
196
+
197
+ def __init__(self, vision_encoder, args, delay_load=False):
198
+ super().__init__()
199
+
200
+ self.is_loaded = False
201
+
202
+ self.vision_encoder_name = vision_encoder
203
+ self.select_layer = args.mm_vision_select_layer
204
+
205
+ if not delay_load:
206
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
207
+ self.load_model(args)
208
+ else:
209
+ # flash-attention-2 may be unavailable at inference time, so fall back to SDPA.
210
+ self.attn_implementation = 'sdpa' # 'eager'
211
+ self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
212
+
213
+ def load_model(self, args):
214
+ if self.is_loaded:
215
+ print('Vision tower is already loaded; skipping repeated `load_model` call.')
216
+ return
217
+
218
+ # merge_size defaults to 1 because STAGE1, STAGE1.5 and STAGE2 are trained with merge_size=1;
220
+ # for STAGE3 it is set to 2 via the training arguments.
220
+ self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.vision_encoder_name)
221
+ self.image_processor.merge_size = args.spatial_merge_size
222
+ # NOTE: caps the number of vision tokens; without `mm_max_length` it falls back to 9477 patches before spatial merging.
223
+ mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
224
+ self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
225
+ self.image_processor.size["max_pixels"] = self.image_processor.max_pixels
226
+
227
+ # merge_size is fixed to 1 inside the encoder for STAGE1, STAGE1.5, STAGE2 and STAGE3; it can be modified in the connector.
228
+ self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
229
+ self.cfg_only.spatial_merge_size = args.spatial_merge_size
230
+
231
+ self.vision_encoder = Qwen2VisionTransformerPretrainedModel.from_pretrained(
232
+ self.vision_encoder_name,
233
+ config=self.cfg_only,
234
+ torch_dtype=args.torch_dtype,
235
+ attn_implementation=self.attn_implementation)
236
+
237
+ self.is_loaded = True
238
+
239
+ def forward(self, images, grid_thws, strides, **kwargs):
240
+ images = [image for sub_images in images for image in sub_images]
241
+ grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
242
+ strides = [stride for sub_strides in strides for stride in sub_strides]
243
+
244
+ images = torch.cat(images, dim=0)
245
+ grid_thws = torch.cat(grid_thws, dim=0)
246
+
247
+ image_features = self.vision_encoder(images, grid_thws, strides=strides)
248
+
249
+ return image_features
250
+
251
+ @property
252
+ def dummy_feature(self):
253
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
254
+
255
+ @property
256
+ def dtype(self):
257
+ return self.vision_encoder.dtype
258
+
259
+ @property
260
+ def device(self):
261
+ return self.vision_encoder.device
262
+
263
+ @property
264
+ def config(self):
265
+ if self.is_loaded:
266
+ return self.vision_encoder.config
267
+ else:
268
+ return self.cfg_only
269
+
270
+ @property
271
+ def hidden_size(self):
272
+ return self.config.hidden_size
273
+
274
+ @property
275
+ def num_patches(self):
276
+ return -1
277
+
278
+ @property
279
+ def num_patches_per_side(self):
280
+ return -1
281
+
282
+ @property
283
+ def image_size(self):
284
+ return 14 * self.vision_encoder.config.spatial_merge_size
285
+
286
+
287
+ class DAMOVLVisionEncoder(nn.Module):
288
+
289
+ def __init__(self, vision_encoder, args, delay_load=False):
290
+ super().__init__()
291
+
292
+ self.is_loaded = False
293
+
294
+ self.vision_encoder_name = vision_encoder
295
+ self.args = args
296
+
297
+ if not delay_load:
298
+ self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
299
+ self.load_model(self.args)
300
+ else:
301
+ # flash-attention-2 may be unavailable at inference time, so fall back to SDPA.
302
+ self.attn_implementation = 'sdpa' # 'eager'
303
+ self.cfg_only = DAMOVLVisionConfig.from_pretrained(self.vision_encoder_name)
304
+
305
+ def load_model(self, args):
306
+ if self.is_loaded:
307
+ print('Vision tower is already loaded; skipping repeated `load_model` call.')
308
+ return
309
+
310
+ # merge_size defaults to 1 because STAGE1, STAGE1.5 and STAGE2 are trained with merge_size=1;
311
+ # for STAGE3 it is set to 2 via the training arguments.
312
+ self.image_processor = DAMOVLImageProcessor.from_pretrained(self.vision_encoder_name)
313
+ self.image_processor.merge_size = args.spatial_merge_size
314
+ # NOTE: caps the number of vision tokens; without `mm_max_length` it falls back to 9477 patches before spatial merging.
315
+ mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
316
+ self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
317
+ self.image_processor.size["max_pixels"] = self.image_processor.max_pixels
318
+
319
+ # merge_size is fixed to 1 inside the encoder for STAGE1, STAGE1.5, STAGE2 and STAGE3; it can be modified in the connector.
320
+ self.cfg_only = DAMOVLVisionConfig.from_pretrained(self.vision_encoder_name)
321
+ self.cfg_only.spatial_merge_size = args.spatial_merge_size
322
+
323
+ self.vision_encoder = DAMOVLVisionModel.from_pretrained(
324
+ self.vision_encoder_name,
325
+ spatial_merge_size=args.spatial_merge_size,
326
+ torch_dtype=args.torch_dtype,
327
+ attn_implementation=self.attn_implementation)
328
+
329
+ self.is_loaded = True
330
+
331
+ def forward(self, images, grid_thws, strides, **kwargs):
332
+ images = [image for sub_images in images for image in sub_images]
333
+ grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
334
+ strides = [stride for sub_strides in strides for stride in sub_strides]
335
+
336
+ images = torch.cat(images, dim=0)
337
+ grid_thws = torch.cat(grid_thws, dim=0)
338
+
339
+ image_features = self.vision_encoder(images, grid_thws, strides)
340
+
341
+ return image_features
342
+
343
+ @property
344
+ def dummy_feature(self):
345
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
346
+
347
+ @property
348
+ def dtype(self):
349
+ return self.vision_encoder.dtype
350
+
351
+ @property
352
+ def device(self):
353
+ return self.vision_encoder.device
354
+
355
+ @property
356
+ def config(self):
357
+ if self.is_loaded:
358
+ return self.vision_encoder.config
359
+ else:
360
+ return self.cfg_only
361
+
362
+ @property
363
+ def hidden_size(self):
364
+ return self.config.hidden_size
365
+
366
+ @property
367
+ def num_patches(self):
368
+ return -1
369
+
370
+ @property
371
+ def num_patches_per_side(self):
372
+ return -1
373
+
374
+ @property
375
+ def image_size(self):
376
+ return 14 * self.vision_encoder.config.spatial_merge_size
377
+
378
+
379
+ def build_vision_encoder(vision_encoder_cfg, **kwargs):
380
+
381
+ vision_encoder = getattr(vision_encoder_cfg, 'mm_vision_encoder', getattr(vision_encoder_cfg, 'vision_encoder', None))
382
+
383
+ vision_encoder = DAMOVLVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
384
+
385
+ return vision_encoder
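For reference, the vision-token budget arithmetic used in `load_model` above, with illustrative defaults (spatial_merge_size=2 and a 14-pixel patch are assumptions matching the values in this file):

    spatial_merge_size = 2
    patch_size = 14
    mm_max_length = 9477 // (spatial_merge_size ** 2)                         # 2369 tokens after 2x2 merging
    max_pixels = mm_max_length * (spatial_merge_size ** 2 * patch_size ** 2)  # 1857296 pixels (~1.86 MP per image)
    print(mm_max_length, max_pixels)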
videollama3/model/processor.py ADDED
@@ -0,0 +1,366 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """
21
+ Processor class for VideoLLaMA3.
22
+ """
23
+ import copy
24
+ import math
25
+ import warnings
26
+ from typing import List, Union, Dict, Optional
27
+
28
+ import torch
29
+ from transformers.feature_extraction_utils import BatchFeature
30
+ from transformers.image_utils import ImageInput, VideoInput
31
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
32
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
33
+
34
+ import sys
35
+ sys.path.append(".")
36
+ from videollama3.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
37
+
38
+
39
+ DEFAULT_CHAT_TEMPLATE = """
40
+ {%- set identifier = 'im' %}
41
+ {% for message in messages %}
42
+ {% if message['role'] == 'stream' %}
43
+ {% set identifier = 'stream' %}
44
+ {% else %}
45
+ {% set identifier = 'im' %}
46
+ {% endif %}
47
+ {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
48
+ {% if message['content'] is string %}
49
+ {{- message['content'] + '<|' + identifier + '_end|>\n' -}}
50
+ {% else %}
51
+ {% for content in message['content'] %}
52
+ {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
53
+ {% if 'time' in content %}
54
+ {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
55
+ {% endif %}
56
+ """
57
+ DEFAULT_CHAT_TEMPLATE += """
58
+ {{- '%s\n' -}}
59
+ """ % DEFAULT_IMAGE_TOKEN
60
+ DEFAULT_CHAT_TEMPLATE += """
61
+ {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
62
+ {% for i in range(content['num_frames']) %}
63
+ {% if 'time' in content %}
64
+ {{- 'Time ' + content['time'][i] | round(1) | string + 's:' -}}
65
+ {% endif %}
66
+ {% if i < content['num_frames'] - 1 %}
67
+ """
68
+ DEFAULT_CHAT_TEMPLATE += """
69
+ {{- '%s,' -}}
70
+ """ % DEFAULT_IMAGE_TOKEN
71
+ DEFAULT_CHAT_TEMPLATE += """
72
+ {% else %}
73
+ """
74
+ DEFAULT_CHAT_TEMPLATE += """
75
+ {{- '%s\n' -}}
76
+ """ % DEFAULT_IMAGE_TOKEN
77
+ DEFAULT_CHAT_TEMPLATE += """
78
+ {% endif %}
79
+ {% endfor %}
80
+ {% elif 'text' in content %}
81
+ {{- content['text'] -}}
82
+ {% endif %}
83
+ {% endfor %}
84
+ {{- '<|' + identifier + '_end|>\n' -}}
85
+ {% endif %}
86
+ {% endfor %}
87
+ {% if add_generation_prompt %}
88
+ {{- '<|im_start|>assistant\n' -}}
89
+ {% endif %}
90
+ """
91
+
92
+
93
+ class Videollama3ProcessorKwargs(ProcessingKwargs, total=False):
94
+ _defaults = {
95
+ "text_kwargs": {
96
+ "padding": False,
97
+ },
98
+ }
99
+
100
+
101
+ class Videollama3Processor(ProcessorMixin):
102
+ r"""
103
+ Modified from Qwen2VLProcessor
104
+ Args:
105
+ image_processor ([`Qwen2VLImageProcessor`], *optional*):
106
+ The image processor is a required input.
107
+ tokenizer ([`Qwen2TokenizerFast`], *optional*):
108
+ The tokenizer is a required input.
109
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
110
+ in a chat into a tokenizable string.
111
+ """
112
+
113
+ attributes = ["image_processor", "tokenizer"]
114
+ valid_kwargs = ["chat_template"]
115
+ image_processor_class = "Qwen2VLImageProcessor"
116
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
117
+
118
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
119
+ if chat_template is None:
120
+ chat_template = DEFAULT_CHAT_TEMPLATE
121
+ # super().__init__(image_processor, tokenizer, chat_template=chat_template)
122
+ tokenizer.chat_template = chat_template
123
+ self.image_processor = image_processor
124
+ self.tokenizer = tokenizer
125
+ self.generation_prompt = self._infer_generation_prompt()
126
+ self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
127
+ self.generation_prompt_length = len(self.generation_prompt_ids[0])
128
+ self.image_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
129
+ self.eos_token_id = self.tokenizer.eos_token_id
130
+
131
+ def get_generation_prompt(self):
132
+ return self.generation_prompt
133
+
134
+ def get_generation_prompt_ids(self):
135
+ return self.generation_prompt_ids
136
+
137
+ def _infer_generation_prompt(self):
138
+ pseudo_message = [{"role": "user", "content": ""}]
139
+ instruction = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
140
+ conversation = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
141
+ return instruction.replace(conversation, "")
142
+
143
+ def _process_text_with_label(
144
+ self,
145
+ text: List[Dict],
146
+ image_grid_thw: torch.Tensor = None,
147
+ image_downsampling: Optional[int] = None,
148
+ **kwargs,
149
+ ):
150
+ assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
151
+ assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."
152
+
153
+ input_ids_list = []
154
+ targets_list = []
155
+ sample_types_list = []
156
+ image_idx = 0
157
+
158
+ for message_idx, message in enumerate(text):
159
+ # 1. set chat template and append image tokens
160
+ prompt = self.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
161
+ prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
162
+ prompt = []
163
+ for chunk_idx in range(len(prompt_chunks) - 1):
164
+ prompt.append(prompt_chunks[chunk_idx])
165
+ thw = image_grid_thw[image_idx]
166
+ prompt.append(DEFAULT_IMAGE_TOKEN * (thw.prod() / image_downsampling**2).long())
167
+ image_idx += 1
168
+ prompt.append(prompt_chunks[-1])
169
+ prompt = "".join(prompt)
170
+
171
+ input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
172
+ input_ids_list.append(input_ids)
173
+
174
+ targets = torch.full_like(input_ids, IGNORE_INDEX)
175
+ sample_types = torch.full_like(input_ids, IGNORE_INDEX)
176
+ if message["role"] == "assistant":
177
+ targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
178
+ elif message["role"] == "stream":
179
+ diff = torch.diff((input_ids == self.image_token_id).float())
180
+ image_end_indices = torch.nonzero(diff < 0)[:, 0]
181
+ targets[image_end_indices + 1] = input_ids[image_end_indices + 1]
182
+ sample_types = targets.clone()
183
+ sample_types[torch.logical_and(sample_types > 0, sample_types != self.eos_token_id)] = 0
184
+ targets[-2] = input_ids[-2] # <|im_end|>
185
+
186
+ # if message_idx > 0 and text[message_idx - 1]["role"] == "stream":
187
+ # targets[0] = input_ids[0]
188
+ # # TODO: consider non-special tokens
189
+ # sample_types[0] = input_ids[0]
190
+
191
+ targets_list.append(targets)
192
+ sample_types_list.append(sample_types)
193
+
194
+ assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."
195
+
196
+ targets = torch.cat(targets_list)
197
+ sample_types = torch.cat(sample_types_list)
198
+ types, counts = torch.unique(sample_types[sample_types > -1], return_counts=True)
199
+
200
+ if len(types) > 0:
201
+ target_num_samples = counts.amin()
202
+
203
+ for type_id, type_count in zip(types, counts):
204
+ if type_count > target_num_samples:
205
+ indices = torch.nonzero(sample_types == type_id)[:, 0]
206
+ random_selector = torch.randperm(indices.size(0))[:-target_num_samples]
207
+ targets[indices[random_selector]] = IGNORE_INDEX
208
+ sample_types[indices[random_selector]] = -1
209
+
210
+ text_inputs = {
211
+ "input_ids": torch.cat(input_ids_list),
212
+ "labels": targets,
213
+ }
214
+
215
+ return text_inputs
216
+
217
+ def _process_text_without_label(
218
+ self,
219
+ text: Union[List[str], List[Dict]],
220
+ image_grid_thw: torch.Tensor = None,
221
+ image_downsampling: Optional[int] = None,
222
+ **kwargs,
223
+ ):
224
+ if isinstance(text[0], dict):
225
+ warnings.warn("Input text is a list of messages. Automatically convert it to a string with 'apply_chat_template' with generation prompt.")
226
+ text = [self.tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)]
227
+
228
+ image_idx = 0
229
+ for i in range(len(text)):
230
+ while DEFAULT_IMAGE_TOKEN in text[i]:
231
+ thw = image_grid_thw[image_idx]
232
+ text[i] = text[i].replace(DEFAULT_IMAGE_TOKEN, "<placeholder>" * (thw.prod() / image_downsampling**2).long(), 1)
233
+ image_idx += 1
234
+ text[i] = text[i].replace("<placeholder>", DEFAULT_IMAGE_TOKEN)
235
+ assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."
236
+
237
+ text_inputs = self.tokenizer(text, **kwargs)
238
+ return text_inputs
239
+
240
+ def _process_text(
241
+ self,
242
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]],
243
+ image_grid_thw: torch.Tensor = None,
244
+ image_downsampling: Optional[int] = None,
245
+ return_labels: bool = False,
246
+ **kwargs,
247
+ ):
248
+ if not isinstance(text, (list, tuple)):
249
+ text = [text]
250
+ assert len(text), "At least one text must be provided."
251
+
252
+ if return_labels:
253
+ return self._process_text_with_label(text, image_grid_thw, image_downsampling, **kwargs)
254
+ return self._process_text_without_label(text, image_grid_thw, image_downsampling, **kwargs)
255
+
256
+ def _process_image(
257
+ self,
258
+ images: ImageInput = None,
259
+ image_downsampling: Optional[int] = None,
260
+ **kwargs,
261
+ ):
262
+ if image_downsampling is None:
263
+ image_downsampling = self.image_processor.merge_size
264
+
265
+ image_inputs = {
266
+ "images": [],
267
+ "grid_thws": [],
268
+ "image_downsampling": image_downsampling
269
+ }
270
+ if images is not None and len(images) > 0:
271
+ num_images = kwargs.get('num_images', len(images))
272
+ if 'num_images' in kwargs:
273
+ kwargs.pop('num_images')
274
+ for image in images:
275
+ outputs = self.image_processor(images=image, num_images=num_images, image_downsampling=image_downsampling, **kwargs)
276
+ # images shapes like: [tensor([patches, 1176]), ...]
277
+ # grid_thws shapes like: tensor([num_images, 3])
278
+
279
+ # flatten_patches1 = outputs["pixel_values"].reshape(26, 46, 3, -1)
280
+ # from matplotlib import pyplot as plt
281
+ # plt.imshow(flatten_patches1[:,:,:,0])
282
+ # plt.savefig('9.png')
283
+
284
+ image_inputs["images"].append(outputs["pixel_values"]) #正常的
285
+
286
+ # flatten_patches1 = image_inputs["images"][0].reshape(26, 46, 3, -1)
287
+ # from matplotlib import pyplot as plt
288
+ # plt.imshow(flatten_patches1[:,:,:,0])
289
+ # plt.savefig('12.png')
290
+ image_inputs["grid_thws"].append(outputs["image_grid_thw"])
291
+
292
+ return image_inputs
293
+
294
+
295
+
296
+ def __call__(
297
+ self,
298
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]] = None,
299
+ images: ImageInput = None,
300
+ image_downsampling: Optional[int] = None,
301
+ return_labels: bool = False,
302
+ **kwargs: Unpack[Videollama3ProcessorKwargs],
303
+ ) -> BatchFeature:
304
+ """
305
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
306
+ and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
307
+ the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
308
+ Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
309
+
310
+ Args:
311
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
312
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
313
+ tensor. Both channels-first and channels-last formats are supported.
314
+ text (`str`, `List[str]`, `List[List[str]]`):
315
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
316
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
317
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
318
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
319
+ If set, will return tensors of a particular framework. Acceptable values are:
320
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
321
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
322
+ - `'np'`: Return NumPy `np.ndarray` objects.
323
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
324
+
325
+ Returns:
326
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
327
+
328
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
329
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
330
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
331
+ `None`).
332
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
333
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
334
+ """
335
+ output_kwargs = self._merge_kwargs(
336
+ Videollama3ProcessorKwargs,
337
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
338
+ **kwargs,
339
+ )
340
+ output_kwargs["text_kwargs"].pop("padding")
341
+ output_kwargs["text_kwargs"].pop("padding_side")
342
+
343
+ image_inputs = self._process_image(images, image_downsampling, **output_kwargs["images_kwargs"])
344
+ text_inputs = self._process_text(text, image_inputs["grid_thws"], image_downsampling, return_labels, **output_kwargs["text_kwargs"])
345
+
346
+ return BatchFeature(data={**text_inputs, **image_inputs})
347
+
348
+ def batch_decode(self, *args, **kwargs):
349
+ """
350
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
351
+ refer to the docstring of this method for more information.
352
+ """
353
+ return self.tokenizer.batch_decode(*args, **kwargs)
354
+
355
+ def decode(self, *args, **kwargs):
356
+ """
357
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
358
+ the docstring of this method for more information.
359
+ """
360
+ return self.tokenizer.decode(*args, **kwargs)
361
+
362
+ @property
363
+ def model_input_names(self):
364
+ tokenizer_input_names = self.tokenizer.model_input_names
365
+ image_processor_input_names = self.image_processor.model_input_names
366
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
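A hedged usage sketch for the processor above; `image_processor`, `tokenizer`, and `pil_image` are assumed to be built or loaded elsewhere, and the image-processor keyword handling follows the `_process_image` path above:

    # processor = Videollama3Processor(image_processor=image_processor, tokenizer=tokenizer)
    conversation = [
        {"role": "user",
         "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]},
    ]
    inputs = processor(text=conversation, images=[pil_image], return_tensors="pt")
    # -> input_ids / attention_mask plus "images", "grid_thws" and "image_downsampling"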
videollama3/model/projector.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright 2024 Alibaba DAMO Academy
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import os
17
+ import re
18
+
19
+ import einops
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from timm.models.layers import LayerNorm, LayerNorm2d
24
+ from timm.models.regnet import RegStage
25
+ from transformers import TRANSFORMERS_CACHE
26
+
27
+
28
+ def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
29
+ revision = "main"
30
+ # 1. parse the downloaded cache folder
31
+ if cache_dir is None:
32
+ cache_dir = TRANSFORMERS_CACHE
33
+ else:
34
+ cache_dir = cache_dir
35
+ object_id = repo_id.replace("/", "--")
36
+ repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
37
+ # 2. resolve refs (for instance to convert main to the associated commit sha)
38
+ refs_dir = os.path.join(repo_cache, "refs")
39
+ if os.path.isdir(refs_dir):
40
+ revision_file = os.path.join(refs_dir, revision)
41
+ if os.path.isfile(revision_file):
42
+ with open(revision_file) as f:
43
+ revision = f.read()
44
+ # 3. acquire the snapshot folder
45
+ folder = os.path.join(repo_cache, "snapshots", revision)
46
+
47
+ return folder
48
+
49
+
50
+ def load_mm_projector(model_path, cache_dir=None, token=None):
51
+ if os.path.exists(os.path.join(model_path, 'mm_projector.bin')):
52
+ is_local = True
53
+ folder = model_path
54
+ else:
55
+ is_local = False
56
+ folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
57
+ if not os.path.exists(os.path.join(folder, 'mm_projector.bin')):
58
+ # downloading from remote repo
59
+ from huggingface_hub import snapshot_download
60
+ snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
61
+
62
+ mm_projector_weights = torch.load(os.path.join(folder, 'mm_projector.bin'), map_location='cpu')
63
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
64
+ return mm_projector_weights
65
+
66
+
67
+ class IdentityMap(nn.Module):
68
+
69
+ def __init__(self):
70
+ super().__init__()
71
+
72
+ def forward(self, x, *args, **kwargs):
73
+ return x
74
+
75
+ @property
76
+ def config(self):
77
+ return {"mm_projector_type": 'identity'}
78
+
79
+
80
+ def build_mlp(depth, hidden_size, output_hidden_size):
81
+ modules = [nn.Linear(hidden_size, output_hidden_size)]
82
+ for _ in range(1, depth):
83
+ modules.append(nn.GELU())
84
+ modules.append(nn.Linear(output_hidden_size, output_hidden_size))
85
+ return nn.Sequential(*modules)
86
+
87
+
88
+ class SimSpatialConv(nn.Module):
89
+
90
+ def __init__(self, config, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
91
+ super().__init__()
92
+ self.encoder_hidden_size = encoder_hidden_size = config.mm_hidden_size
93
+ self.output_hidden_size = output_hidden_size = config.hidden_size
94
+ self.downsample = downsample
95
+ self.padding = padding
96
+ self.sampler = nn.Sequential(
97
+ nn.Conv2d(
98
+ in_channels=self.encoder_hidden_size,
99
+ out_channels=4 * self.encoder_hidden_size,
100
+ kernel_size=self.downsample,
101
+ stride=self.downsample,
102
+ padding=self.padding,
103
+ bias=True
104
+ ),
105
+ nn.SiLU(),
106
+ )
107
+ self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)
108
+
109
+ def forward(self, x):
110
+ hw = int(x.size(1) ** 0.5)
111
+ x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
112
+ x = self.sampler(x)
113
+ x = einops.rearrange(x, "b d h w -> b (h w) d")
114
+ x = self.readout(x)
115
+ return x
116
+
117
+ def cal_proj_size(self, input_size):
118
+ if isinstance(input_size, int):
119
+ input_size = (input_size, input_size)
120
+ height = math.ceil((input_size[0] + self.padding) / self.downsample[0])
121
+ width = math.ceil((input_size[1] + self.padding) / self.downsample[1])
122
+ return height * width
123
+
124
+
125
+ class MlpGeluProjector(nn.Module):
126
+ def __init__(self, config, projector_type):
127
+ super().__init__()
128
+
129
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
130
+ mlp_depth = int(mlp_gelu_match.group(1))
131
+
132
+ self.readout = build_mlp(mlp_depth, config.mm_hidden_size, config.hidden_size)
133
+
134
+ def forward(self, x):
135
+ x = self.readout(x)
136
+ return x
137
+
138
+ def cal_proj_size(self, input_size):
139
+ if isinstance(input_size, int):
140
+ input_size = (input_size, input_size)
141
+ height = input_size[0]
142
+ width = input_size[1]
143
+ return height * width
144
+
145
+
146
+ def build_vision_projector(config, delay_load=False, **kwargs):
147
+ # The videollama3 projector only supports image-wise operation for now, i.e., no temporal aggregation.
148
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
149
+
150
+ if projector_type == "linear":
151
+ # NOTE: for both the linear and mlp2x_gelu projector types, mean pooling is adopted to aggregate video features.
152
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
153
+ elif projector_type == "simp_spatial_conv":
154
+ return SimSpatialConv(config)
155
+ elif projector_type.startswith("mlp"):
156
+ return MlpGeluProjector(config, projector_type)
157
+ if projector_type == 'identity':
158
+ return IdentityMap()
159
+
160
+ raise ValueError(f'Unknown projector type: {projector_type}')
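A minimal sketch of the projector factory above; the config stub carries only the fields the projectors read, and the sizes are assumptions:

    from types import SimpleNamespace

    cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1152, hidden_size=3584)
    projector = build_vision_projector(cfg)
    # mlp2x_gelu -> Linear(1152, 3584) + GELU + Linear(3584, 3584), applied token-wise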
videollama3/model/qwen2vl_encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
2
+ from .image_processing import Qwen2VLImageProcessor
3
+ from .modeling_qwen2vl_encoder import Qwen2VisionTransformerPretrainedModel
videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (432 Bytes). View file
 
videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc ADDED
Binary file (1.92 kB). View file
 
videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py ADDED
@@ -0,0 +1,72 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2VL model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class Qwen2VLVisionConfig(PretrainedConfig):
28
+ model_type = "qwen2_vl"
29
+
30
+ def __init__(
31
+ self,
32
+ depth=32,
33
+ embed_dim=1280,
34
+ hidden_size=3584,
35
+ hidden_act="quick_gelu",
36
+ mlp_ratio=4,
37
+ num_heads=16,
38
+ in_channels=3,
39
+ patch_size=14,
40
+ spatial_merge_size=2,
41
+ temporal_patch_size=2,
42
+ **kwargs,
43
+ ):
44
+ super().__init__(**kwargs)
45
+
46
+ self.depth = depth
47
+ self.embed_dim = embed_dim
48
+ self.hidden_size = hidden_size
49
+ self.hidden_act = hidden_act
50
+ self.mlp_ratio = mlp_ratio
51
+ self.num_heads = num_heads
52
+ self.in_channels = in_channels
53
+ self.patch_size = patch_size
54
+ self.spatial_merge_size = spatial_merge_size
55
+ self.temporal_patch_size = temporal_patch_size
56
+
57
+ @classmethod
58
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
59
+ cls._set_token_in_kwargs(kwargs)
60
+
61
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
62
+
63
+ # if config_dict.get("model_type") == "qwen2_vl":
64
+ # config_dict = config_dict["vision_config"]
65
+
66
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
67
+ logger.warning(
68
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
69
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
70
+ )
71
+
72
+ return cls.from_dict(config_dict, **kwargs)
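
A short sketch of building the config directly (it assumes the import path exposed by the sub-package's `__init__.py` above; the overridden values are arbitrary):

from videollama3.model.qwen2vl_encoder import Qwen2VLVisionConfig

# Defaults mirror the constructor above; any field can be overridden.
cfg = Qwen2VLVisionConfig(depth=32, embed_dim=1280, num_heads=16)
print(cfg.model_type)                  # "qwen2_vl"
print(cfg.embed_dim // cfg.num_heads)  # per-head dimension: 1280 // 16 = 80
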
videollama3/model/qwen2vl_encoder/image_processing.py ADDED
@@ -0,0 +1,469 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """Image processor class for Qwen2-VL."""
21
+
22
+ import math
23
+ from typing import Dict, List, Optional, Union
24
+
25
+ import numpy as np
26
+
27
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
28
+ from transformers.image_transforms import (
29
+ convert_to_rgb,
30
+ resize,
31
+ to_channel_dimension_format,
32
+ )
33
+ from transformers.image_utils import (
34
+ OPENAI_CLIP_MEAN,
35
+ OPENAI_CLIP_STD,
36
+ ChannelDimension,
37
+ ImageInput,
38
+ PILImageResampling,
39
+ VideoInput,
40
+ get_image_size,
41
+ infer_channel_dimension_format,
42
+ is_scaled_image,
43
+ is_valid_image,
44
+ make_list_of_images,
45
+ to_numpy_array,
46
+ valid_images,
47
+ validate_preprocess_arguments,
48
+ )
49
+ from transformers.utils import TensorType, is_vision_available, logging
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+
55
+ if is_vision_available():
56
+ from PIL import Image
57
+
58
+
59
+ def make_batched_images(images) -> List[List[ImageInput]]:
60
+ """
61
+ Accepts images in list or nested list format, and makes a list of images for preprocessing.
62
+
63
+ Args:
64
+ images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
65
+ The input image.
66
+
67
+ Returns:
68
+ list: A list of images.
69
+ """
70
+ if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
71
+ return [img for img_list in images for img in img_list]
72
+
73
+ elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
74
+ return images
75
+
76
+ elif is_valid_image(images):
77
+ return [images]
78
+
79
+ raise ValueError(f"Could not make batched images from {images}")
80
+
81
+
82
+ # Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
83
+ def make_batched_videos(videos) -> List[VideoInput]:
84
+ if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
85
+ return videos
86
+
87
+ elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
88
+ if isinstance(videos[0], Image.Image):
89
+ return [videos]
90
+ elif len(videos[0].shape) == 4:
91
+ return [list(video) for video in videos]
92
+
93
+ elif is_valid_image(videos) and len(videos.shape) == 4:
94
+ return [list(videos)]
95
+
96
+ raise ValueError(f"Could not make batched video from {videos}")
97
+
98
+
99
+ def smart_resize(
100
+ height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
101
+ ):
102
+ """Rescales the image so that the following conditions are met:
103
+
104
+ 1. Both dimensions (height and width) are divisible by 'factor'.
105
+
106
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
107
+
108
+ 3. The aspect ratio of the image is maintained as closely as possible.
109
+
110
+ """
111
+ if height < factor or width < factor:
112
+ scale = factor / min(height, width)
113
+ width = round(scale * width)
114
+ height = round(scale * height)
115
+ elif max(height, width) / min(height, width) > 200:
116
+ raise ValueError(
117
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
118
+ )
119
+ h_bar = round(height / factor) * factor
120
+ w_bar = round(width / factor) * factor
121
+ if h_bar * w_bar > max_pixels:
122
+ beta = math.sqrt((height * width) / max_pixels)
123
+ h_bar = math.floor(height / beta / factor) * factor
124
+ w_bar = math.floor(width / beta / factor) * factor
125
+ elif h_bar * w_bar < min_pixels:
126
+ beta = math.sqrt(min_pixels / (height * width))
127
+ h_bar = math.ceil(height * beta / factor) * factor
128
+ w_bar = math.ceil(width * beta / factor) * factor
129
+ return h_bar, w_bar
130
+
131
+
132
+ class Qwen2VLImageProcessor(BaseImageProcessor):
133
+ r"""
134
+ Constructs a Qwen2-VL image processor that dynamically resizes images based on their original resolution.
135
+
136
+ Args:
137
+ do_resize (`bool`, *optional*, defaults to `True`):
138
+ Whether to resize the image's (height, width) dimensions.
139
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
140
+ Resampling filter to use when resizing the image.
141
+ do_rescale (`bool`, *optional*, defaults to `True`):
142
+ Whether to rescale the image by the specified scale `rescale_factor`.
143
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
144
+ Scale factor to use if rescaling the image.
145
+ do_normalize (`bool`, *optional*, defaults to `True`):
146
+ Whether to normalize the image.
147
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
148
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
149
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
150
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
151
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
152
+ Whether to convert the image to RGB.
153
+ min_pixels (`int`, *optional*, defaults to `56 * 56`):
154
+ The minimum number of pixels an image may be resized to.
155
+ max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
156
+ The maximum number of pixels an image may be resized to.
157
+ patch_size (`int`, *optional*, defaults to 14):
158
+ The spatial patch size of the vision encoder.
159
+ temporal_patch_size (`int`, *optional*, defaults to 2):
160
+ The temporal patch size of the vision encoder.
161
+ merge_size (`int`, *optional*, defaults to 2):
162
+ The spatial merge size used when mapping vision-encoder patches to LLM tokens.
163
+ """
164
+
165
+ model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
166
+
167
+ def __init__(
168
+ self,
169
+ do_resize: bool = True,
170
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
171
+ do_rescale: bool = True,
172
+ rescale_factor: Union[int, float] = 1 / 255,
173
+ do_normalize: bool = True,
174
+ image_mean: Optional[Union[float, List[float]]] = None,
175
+ image_std: Optional[Union[float, List[float]]] = None,
176
+ do_convert_rgb: bool = True,
177
+ min_pixels: int = 56 * 56,
178
+ max_pixels: int = 28 * 28 * 1280,
179
+ patch_size: int = 14,
180
+ temporal_patch_size: int = 2,
181
+ merge_size: int = 2,
182
+ **kwargs,
183
+ ) -> None:
184
+ super().__init__(**kwargs)
185
+ self.do_resize = do_resize
186
+ self.resample = resample
187
+ self.do_rescale = do_rescale
188
+ self.rescale_factor = rescale_factor
189
+ self.do_normalize = do_normalize
190
+ self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
191
+ self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
192
+ self.min_pixels = min_pixels
193
+ self.max_pixels = max_pixels
194
+ self.patch_size = patch_size
195
+ self.temporal_patch_size = temporal_patch_size
196
+ self.merge_size = merge_size
197
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
198
+ self.do_convert_rgb = do_convert_rgb
199
+
200
+ def _preprocess(
201
+ self,
202
+ images: Union[ImageInput, VideoInput],
203
+ do_resize: bool = None,
204
+ resample: PILImageResampling = None,
205
+ do_rescale: bool = None,
206
+ rescale_factor: float = None,
207
+ do_normalize: bool = None,
208
+ image_mean: Optional[Union[float, List[float]]] = None,
209
+ image_std: Optional[Union[float, List[float]]] = None,
210
+ do_convert_rgb: bool = None,
211
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
212
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
213
+ num_images: Optional[int] = 1,
214
+ image_downsampling: Optional[int] = None,
215
+ ):
216
+ """
217
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
218
+
219
+ Args:
220
+ images (`ImageInput`):
221
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
222
+ vision_info (`List[Dict]`, *optional*):
223
+ Optional list of dictionaries containing additional information about vision inputs.
224
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
225
+ Whether to resize the image.
226
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
227
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
228
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
229
+ Whether to rescale the image.
230
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
231
+ Scale factor to use if rescaling the image.
232
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
233
+ Whether to normalize the image.
234
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
235
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
236
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
237
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
238
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
239
+ Whether to convert the image to RGB.
240
+ data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
241
+ The channel dimension format for the output image. Can be one of:
242
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
243
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
244
+ - Unset: Use the channel dimension format of the input image.
245
+ input_data_format (`ChannelDimension` or `str`, *optional*):
246
+ The channel dimension format for the input image. Can be one of:
247
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
248
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
249
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
250
+ """
251
+ images = make_list_of_images(images)
252
+
253
+ if do_convert_rgb:
254
+ images = [convert_to_rgb(image) for image in images]
255
+
256
+ # All transformations expect numpy arrays.
257
+ images = [to_numpy_array(image) for image in images]
258
+
259
+ if is_scaled_image(images[0]) and do_rescale:
260
+ logger.warning_once(
261
+ "It looks like you are trying to rescale already rescaled images. If the input"
262
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
263
+ )
264
+ if input_data_format is None:
265
+ # We assume that all images have the same channel dimension format.
266
+ input_data_format = infer_channel_dimension_format(images[0])
267
+
268
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
269
+ resized_height, resized_width = height, width
270
+ processed_images = []
271
+ for image in images:
272
+ if do_resize:
273
+ max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
274
+ resized_height, resized_width = smart_resize(
275
+ height,
276
+ width,
277
+ factor=self.patch_size * image_downsampling,
278
+ min_pixels=self.min_pixels,
279
+ max_pixels=int(max_pixels // num_images),
280
+ )
281
+ image = resize(
282
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
283
+ )
284
+
285
+ if do_rescale:
286
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
287
+
288
+ if do_normalize:
289
+ image = self.normalize(
290
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
291
+ )
292
+
293
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
294
+ processed_images.append(image)
295
+
296
+ patches = np.array(processed_images)
297
+ if data_format == ChannelDimension.LAST:
298
+ patches = patches.transpose(0, 3, 1, 2)
299
+ if patches.shape[0] == 1:
300
+ patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
301
+ channel = patches.shape[1]
302
+ grid_t = patches.shape[0] // self.temporal_patch_size
303
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
304
+ patches = patches.reshape(
305
+ grid_t,
306
+ self.temporal_patch_size,
307
+ channel,
308
+ grid_h // image_downsampling,
309
+ image_downsampling,
310
+ self.patch_size,
311
+ grid_w // image_downsampling,
312
+ image_downsampling,
313
+ self.patch_size,
314
+ )
315
+ patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
316
+ flatten_patches = patches.reshape(
317
+ grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
318
+ )
319
+ return flatten_patches, (grid_t, grid_h, grid_w)
320
+
321
+ def preprocess(
322
+ self,
323
+ images: ImageInput,
324
+ videos: VideoInput = None,
325
+ do_resize: bool = None,
326
+ size: Dict[str, int] = None,
327
+ resample: PILImageResampling = None,
328
+ do_rescale: bool = None,
329
+ rescale_factor: float = None,
330
+ do_normalize: bool = None,
331
+ image_mean: Optional[Union[float, List[float]]] = None,
332
+ image_std: Optional[Union[float, List[float]]] = None,
333
+ do_convert_rgb: bool = None,
334
+ return_tensors: Optional[Union[str, TensorType]] = None,
335
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
336
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
337
+ num_images: Optional[int] = 1,
338
+ image_downsampling: Optional[int] = None,
339
+ ):
340
+ """
341
+ Args:
342
+ images (`ImageInput`):
343
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
344
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
345
+ videos (`VideoInput`):
346
+ Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
347
+ passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
348
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
349
+ Whether to resize the image.
350
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
351
+ Size constraints used when resizing. For this processor, this is a dict with `min_pixels` and
352
+ `max_pixels` bounding the total number of pixels of the resized image.
353
+ resample (`int`, *optional*, defaults to `self.resample`):
354
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
355
+ has an effect if `do_resize` is set to `True`.
356
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
357
+ Whether to rescale the image.
358
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
359
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
360
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
361
+ Whether to normalize the image.
362
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
363
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
364
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
365
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
366
+ `True`.
367
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
368
+ Whether to convert the image to RGB.
369
+ return_tensors (`str` or `TensorType`, *optional*):
370
+ The type of tensors to return. Can be one of:
371
+ - Unset: Return a list of `np.ndarray`.
372
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
373
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
374
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
375
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
376
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
377
+ The channel dimension format for the output image. Can be one of:
378
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
379
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
380
+ - Unset: Use the channel dimension format of the input image.
381
+ input_data_format (`ChannelDimension` or `str`, *optional*):
382
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
383
+ from the input image. Can be one of:
384
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
385
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
386
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
387
+
388
+ """
389
+ do_resize = do_resize if do_resize is not None else self.do_resize
390
+ size = size if size is not None else self.size
391
+ resample = resample if resample is not None else self.resample
392
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
393
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
394
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
395
+ image_mean = image_mean if image_mean is not None else self.image_mean
396
+ image_std = image_std if image_std is not None else self.image_std
397
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
398
+ image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size
399
+
400
+ if images is not None:
401
+ images = make_batched_images(images)
402
+ if videos is not None:
403
+ videos = make_batched_videos(videos)
404
+
405
+ if images is not None and not valid_images(images):
406
+ raise ValueError(
407
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
408
+ "torch.Tensor, tf.Tensor or jax.ndarray."
409
+ )
410
+
411
+ validate_preprocess_arguments(
412
+ rescale_factor=rescale_factor,
413
+ do_normalize=do_normalize,
414
+ image_mean=image_mean,
415
+ image_std=image_std,
416
+ do_resize=do_resize,
417
+ size=size,
418
+ resample=resample,
419
+ )
420
+
421
+ if images is not None:
422
+ pixel_values, vision_grid_thws = [], []
423
+ for image in images:
424
+ patches, image_grid_thw = self._preprocess(
425
+ image,
426
+ do_resize=do_resize,
427
+ resample=resample,
428
+ do_rescale=do_rescale,
429
+ rescale_factor=rescale_factor,
430
+ do_normalize=do_normalize,
431
+ image_mean=image_mean,
432
+ image_std=image_std,
433
+ data_format=data_format,
434
+ do_convert_rgb=do_convert_rgb,
435
+ input_data_format=input_data_format,
436
+ num_images=num_images,
437
+ image_downsampling=image_downsampling,
438
+ )
439
+ pixel_values.extend(patches)
440
+ vision_grid_thws.append(image_grid_thw)
441
+ pixel_values = np.array(pixel_values)
442
+ vision_grid_thws = np.array(vision_grid_thws)
443
+ data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
444
+
445
+ if videos is not None:
446
+ pixel_values, vision_grid_thws = [], []
447
+ for images in videos:
448
+ patches, video_grid_thw = self._preprocess(
449
+ images,
450
+ do_resize=do_resize,
451
+ resample=resample,
452
+ do_rescale=do_rescale,
453
+ rescale_factor=rescale_factor,
454
+ do_normalize=do_normalize,
455
+ image_mean=image_mean,
456
+ image_std=image_std,
457
+ data_format=data_format,
458
+ do_convert_rgb=do_convert_rgb,
459
+ input_data_format=input_data_format,
460
+ num_images=num_images,
461
+ image_downsampling=image_downsampling,
462
+ )
463
+ pixel_values.extend(patches)
464
+ vision_grid_thws.append(video_grid_thw)
465
+ pixel_values = np.array(pixel_values)
466
+ vision_grid_thws = np.array(vision_grid_thws)
467
+ data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
468
+
469
+ return BatchFeature(data=data, tensor_type=return_tensors)
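
To make the resizing rule above concrete, here is a small sketch around `smart_resize` (the import path is the one this diff adds; the 1080x1920 input is just an example):

from videollama3.model.qwen2vl_encoder.image_processing import smart_resize

# factor = patch_size * merge_size = 14 * 2 = 28 for the defaults above.
h, w = smart_resize(1080, 1920, factor=28)

# Both dimensions are divisible by 28 and h * w falls inside
# [min_pixels, max_pixels]; with the defaults this yields (728, 1316).
print(h, w)
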
videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py ADDED
@@ -0,0 +1,367 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2-VL model."""
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+ import torch.utils.checkpoint
30
+ from torch.nn import CrossEntropyLoss, LayerNorm
31
+ from transformers.activations import ACT2FN
32
+ from transformers.cache_utils import Cache, StaticCache
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (add_start_docstrings,
37
+ add_start_docstrings_to_model_forward,
38
+ is_flash_attn_2_available,
39
+ is_flash_attn_greater_or_equal_2_10, logging,
40
+ replace_return_docstrings)
41
+
42
+ from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
43
+
44
+ if is_flash_attn_2_available():
45
+ from flash_attn import flash_attn_varlen_func
46
+ from transformers.modeling_flash_attention_utils import \
47
+ _flash_attention_forward
48
+ else:
49
+ flash_attn_varlen_func = None
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
55
+ def rotate_half(x):
56
+ """Rotates half the hidden dims of the input."""
57
+ x1 = x[..., : x.shape[-1] // 2]
58
+ x2 = x[..., x.shape[-1] // 2 :]
59
+ return torch.cat((-x2, x1), dim=-1)
60
+
61
+
62
+ def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
63
+ orig_dtype = tensor.dtype
64
+ tensor = tensor.float()
65
+ cos = freqs.cos()
66
+ sin = freqs.sin()
67
+ cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
68
+ sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
69
+ output = (tensor * cos) + (rotate_half(tensor) * sin)
70
+ output = output.to(orig_dtype)
71
+ return output
72
+
73
+
74
+ class VisionRotaryEmbedding(nn.Module):
75
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
76
+ super().__init__()
77
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
78
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
79
+
80
+ def forward(self, seqlen: int) -> torch.Tensor:
81
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
82
+ freqs = torch.outer(seq, self.inv_freq)
83
+ return freqs
84
+
85
+
86
+ class PatchEmbed(nn.Module):
87
+ def __init__(
88
+ self,
89
+ patch_size: int = 14,
90
+ temporal_patch_size: int = 2,
91
+ in_channels: int = 3,
92
+ embed_dim: int = 1152,
93
+ ) -> None:
94
+ super().__init__()
95
+ self.patch_size = patch_size
96
+ self.temporal_patch_size = temporal_patch_size
97
+ self.in_channels = in_channels
98
+ self.embed_dim = embed_dim
99
+
100
+ kernel_size = [temporal_patch_size, patch_size, patch_size]
101
+ self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
102
+
103
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
104
+ target_dtype = self.proj.weight.dtype
105
+ hidden_states = hidden_states.view(
106
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
107
+ )
108
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
109
+ return hidden_states
110
+
111
+
112
+ class PatchMerger(nn.Module):
113
+ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
114
+ super().__init__()
115
+ self.hidden_size = context_dim * (spatial_merge_size**2)
116
+ self.ln_q = LayerNorm(context_dim, eps=1e-6)
117
+ self.mlp = nn.Sequential(
118
+ nn.Linear(self.hidden_size, self.hidden_size),
119
+ nn.GELU(),
120
+ nn.Linear(self.hidden_size, dim),
121
+ )
122
+
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
125
+ return x
126
+
127
+
128
+ class VisionMlp(nn.Module):
129
+ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
130
+ super().__init__()
131
+ self.fc1 = nn.Linear(dim, hidden_dim)
132
+ self.act = ACT2FN[hidden_act]
133
+ self.fc2 = nn.Linear(hidden_dim, dim)
134
+
135
+ def forward(self, x) -> torch.Tensor:
136
+ return self.fc2(self.act(self.fc1(x)))
137
+
138
+
139
+ class VisionAttention(nn.Module):
140
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
141
+ super().__init__()
142
+ self.num_heads = num_heads
143
+ self.head_dim = dim // num_heads
144
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
145
+ self.proj = nn.Linear(dim, dim)
146
+
147
+ def forward(
148
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
149
+ ) -> torch.Tensor:
150
+ seq_length = hidden_states.shape[0]
151
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
152
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
153
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
154
+
155
+ attention_mask = torch.full(
156
+ [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
157
+ )
158
+ for i in range(1, len(cu_seqlens)):
159
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
160
+
161
+ q = q.transpose(0, 1)
162
+ k = k.transpose(0, 1)
163
+ v = v.transpose(0, 1)
164
+ attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
165
+ attn_weights = attn_weights + attention_mask
166
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
167
+ attn_output = torch.matmul(attn_weights, v)
168
+ attn_output = attn_output.transpose(0, 1)
169
+ attn_output = attn_output.reshape(seq_length, -1)
170
+ attn_output = self.proj(attn_output)
171
+ return attn_output
172
+
173
+
174
+ class VisionFlashAttention2(nn.Module):
175
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
176
+ super().__init__()
177
+ self.num_heads = num_heads
178
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
179
+ self.proj = nn.Linear(dim, dim)
180
+
181
+ def forward(
182
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
183
+ ) -> torch.Tensor:
184
+ seq_length = hidden_states.shape[0]
185
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
186
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
187
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
188
+
189
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
190
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
191
+ seq_length, -1
192
+ )
193
+ attn_output = self.proj(attn_output)
194
+ return attn_output
195
+
196
+
197
+ class VisionSdpaAttention(nn.Module):
198
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
199
+ super().__init__()
200
+ self.num_heads = num_heads
201
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
202
+ self.proj = nn.Linear(dim, dim)
203
+
204
+ def forward(
205
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
206
+ ) -> torch.Tensor:
207
+ seq_length = hidden_states.shape[0]
208
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
209
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
210
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
211
+
212
+ attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
213
+ for i in range(1, len(cu_seqlens)):
214
+ attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
215
+ q = q.transpose(0, 1)
216
+ k = k.transpose(0, 1)
217
+ v = v.transpose(0, 1)
218
+ attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
219
+ attn_output = attn_output.transpose(0, 1)
220
+ attn_output = attn_output.reshape(seq_length, -1)
221
+ attn_output = self.proj(attn_output)
222
+ return attn_output
223
+
224
+
225
+ QWEN2_VL_VISION_ATTENTION_CLASSES = {
226
+ "eager": VisionAttention,
227
+ "flash_attention_2": VisionFlashAttention2,
228
+ "sdpa": VisionSdpaAttention,
229
+ }
230
+
231
+
232
+ class Qwen2VLVisionBlock(nn.Module):
233
+ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
234
+ super().__init__()
235
+ self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
236
+ self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
237
+ mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
238
+
239
+ self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
240
+ config.embed_dim, num_heads=config.num_heads
241
+ )
242
+ self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
243
+
244
+ def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
245
+ hidden_states = hidden_states + self.attn(
246
+ self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
247
+ )
248
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
249
+ return hidden_states
250
+
251
+
252
+ class Qwen2VLPreTrainedModel(PreTrainedModel):
253
+ config_class = Qwen2VLVisionConfig
254
+ base_model_prefix = "model"
255
+ supports_gradient_checkpointing = True
256
+ _no_split_modules = ["Qwen2VLVisionBlock"]
257
+ _skip_keys_device_placement = "past_key_values"
258
+ _supports_flash_attn_2 = True
259
+ _supports_sdpa = True
260
+ _supports_cache_class = True
261
+ _supports_static_cache = True
262
+
263
+ def _init_weights(self, module):
264
+ std = self.config.initializer_range
265
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
266
+ module.weight.data.normal_(mean=0.0, std=std)
267
+ if module.bias is not None:
268
+ module.bias.data.zero_()
269
+ elif isinstance(module, nn.Embedding):
270
+ module.weight.data.normal_(mean=0.0, std=std)
271
+ if module.padding_idx is not None:
272
+ module.weight.data[module.padding_idx].zero_()
273
+
274
+
275
+ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
276
+ config_class = Qwen2VLVisionConfig
277
+ _no_split_modules = ["Qwen2VLVisionBlock"]
278
+
279
+ def __init__(self, config) -> None:
280
+ super().__init__(config)
281
+ self.spatial_merge_size = config.spatial_merge_size
282
+ self.gradient_checkpointing = False
283
+
284
+ self.patch_embed = PatchEmbed(
285
+ patch_size=config.patch_size,
286
+ temporal_patch_size=config.temporal_patch_size,
287
+ in_channels=config.in_channels,
288
+ embed_dim=config.embed_dim,
289
+ )
290
+
291
+ head_dim = config.embed_dim // config.num_heads
292
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
293
+
294
+ self.blocks = nn.ModuleList(
295
+ [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
296
+ )
297
+ #
298
+ # if self.spatial_merge_size > 1:
299
+ # self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
300
+
301
+ def get_dtype(self) -> torch.dtype:
302
+ return self.blocks[0].mlp.fc2.weight.dtype
303
+
304
+ def get_device(self) -> torch.device:
305
+ return self.blocks[0].mlp.fc2.weight.device
306
+
307
+ def rot_pos_emb(self, grid_thw, strides):
308
+ pos_ids = []
309
+ for (t, h, w), stride in zip(grid_thw, strides):
310
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
311
+ hpos_ids = hpos_ids.reshape(
312
+ h // stride,
313
+ stride,
314
+ w // stride,
315
+ stride,
316
+ )
317
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
318
+ hpos_ids = hpos_ids.flatten()
319
+
320
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
321
+ wpos_ids = wpos_ids.reshape(
322
+ h // stride,
323
+ stride,
324
+ w // stride,
325
+ stride,
326
+ )
327
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
328
+ wpos_ids = wpos_ids.flatten()
329
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
330
+ pos_ids = torch.cat(pos_ids, dim=0)
331
+ max_grid_size = grid_thw[:, 1:].max()
332
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
333
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
334
+ return rotary_pos_emb
335
+
336
+ def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
337
+ hidden_states = self.patch_embed(hidden_states)
338
+
339
+ # BUG: this code will cause a deepspeed issue: `RuntimeError: disagreement between rank0 and rankx`
340
+ # rotary_pos_emb = []
341
+ # for thw in grid_thws:
342
+ # rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
343
+ # rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
344
+ # grid_thws = torch.cat(grid_thws, dim = 0)
345
+
346
+ # new version of creating the rotary position embedding:
347
+ # grid_thws has shape [batch_flatten_image_num, 3]
348
+ # grid_thws = torch.cat(grid_thws, dim=0)  # already performed in `encoder.py`
349
+ rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)
350
+
351
+ cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
352
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
353
+
354
+ for blk in self.blocks:
355
+ if self.gradient_checkpointing and self.training:
356
+ hidden_states = self._gradient_checkpointing_func(
357
+ blk.__call__,
358
+ hidden_states,
359
+ cu_seqlens,
360
+ rotary_pos_emb
361
+ )
362
+ else:
363
+ hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
364
+
365
+ # if self.spatial_merge_size > 1:
366
+ # hidden_states = self.merger(hidden_states)
367
+ return hidden_states
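
Finally, a tiny end-to-end sketch of the encoder's input contract (a randomly initialised toy model, not the released checkpoint; the explicit `_attn_implementation` assignment is a hedge so the block construction above resolves to the eager attention path regardless of the installed transformers version):

import torch

from videollama3.model.qwen2vl_encoder import (
    Qwen2VLVisionConfig,
    Qwen2VisionTransformerPretrainedModel,
)

# Toy configuration: 2 blocks, 64-dim embeddings, 4 heads.
cfg = Qwen2VLVisionConfig(depth=2, embed_dim=64, num_heads=4)
cfg._attn_implementation = "eager"
model = Qwen2VisionTransformerPretrainedModel(cfg).eval()

# One image laid out as a (t, h, w) = (1, 4, 4) patch grid; each flattened patch
# holds in_channels * temporal_patch_size * patch_size**2 = 3 * 2 * 14 * 14 values.
grid_thws = torch.tensor([[1, 4, 4]])
patches = torch.randn(16, 3 * 2 * 14 * 14)

with torch.no_grad():
    out = model(patches, grid_thws, strides=[2])
print(out.shape)  # torch.Size([16, 64]) -- one embedding per patch (no merger applied)
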