Spaces:
Running
on
Zero
Running
on
Zero
init
#1
by
CircleRadon
- opened
This view is limited to 50 files because it contains too many changes.
See the raw diff here.
- .DS_Store +0 -0
- .gitattributes +14 -0
- app.py +562 -0
- demo/.DS_Store +0 -0
- demo/images/1.jpg +3 -0
- demo/images/2.jpg +3 -0
- demo/images/3.jpg +3 -0
- demo/images/4.jpg +3 -0
- demo/images/5.jpg +3 -0
- demo/images/6.jpg +3 -0
- demo/images/7.jpg +3 -0
- demo/images/8.jpg +3 -0
- demo/images/LICENSE +3 -0
- demo/videos/1.mp4 +3 -0
- demo/videos/2.mp4 +3 -0
- demo/videos/3.mp4 +3 -0
- demo/videos/4.mp4 +3 -0
- requirements.txt +48 -0
- videollama3/.DS_Store +0 -0
- videollama3/__init__.py +239 -0
- videollama3/constants.py +46 -0
- videollama3/infer.py +82 -0
- videollama3/mm_utils.py +704 -0
- videollama3/model/__init__.py +166 -0
- videollama3/model/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/encoder.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/processor.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/projector.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/region_encoder.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__init__.py +3 -0
- videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/configuration_damovl_encoder.py +71 -0
- videollama3/model/damovl_encoder/image_processing.py +472 -0
- videollama3/model/damovl_encoder/modeling_damovl_encoder.py +542 -0
- videollama3/model/encoder.py +385 -0
- videollama3/model/processor.py +366 -0
- videollama3/model/projector.py +160 -0
- videollama3/model/qwen2vl_encoder/__init__.py +3 -0
- videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py +72 -0
- videollama3/model/qwen2vl_encoder/image_processing.py +469 -0
- videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py +367 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
demo/videos/ filter=lfs diff=lfs merge=lfs -text
|
37 |
+
demo/videos/3.mp4 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
demo/videos/4.mp4 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
demo/videos/1.mp4 filter=lfs diff=lfs merge=lfs -text
|
40 |
+
demo/videos/2.mp4 filter=lfs diff=lfs merge=lfs -text
|
41 |
+
demo/images/4.jpg filter=lfs diff=lfs merge=lfs -text
|
42 |
+
demo/images/5.jpg filter=lfs diff=lfs merge=lfs -text
|
43 |
+
demo/images/6.jpg filter=lfs diff=lfs merge=lfs -text
|
44 |
+
demo/images/8.jpg filter=lfs diff=lfs merge=lfs -text
|
45 |
+
demo/images/1.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
demo/images/2.jpg filter=lfs diff=lfs merge=lfs -text
|
47 |
+
demo/images/3.jpg filter=lfs diff=lfs merge=lfs -text
|
48 |
+
demo/images/7.jpg filter=lfs diff=lfs merge=lfs -text
|
49 |
+
demo/images/LICENSE filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from transformers import SamModel, SamProcessor
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
import cv2
|
8 |
+
import argparse
|
9 |
+
import sys
|
10 |
+
# This is for making model initialization faster and has no effect since we are loading the weights
|
11 |
+
sys.path.append('./')
|
12 |
+
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
|
13 |
+
from videollama3.mm_utils import load_images
|
14 |
+
from videollama3.mm_utils import load_video
|
15 |
+
|
16 |
+
|
17 |
+
color_rgb = (1.0, 1.0, 1.0)
|
18 |
+
color_rgbs = [
|
19 |
+
(1.0, 1.0, 1.0),
|
20 |
+
(1.0, 0.0, 0.0),
|
21 |
+
(0.0, 1.0, 1.0),
|
22 |
+
(0.0, 1.0, 0.0),
|
23 |
+
(0.0, 0.0, 1.0),
|
24 |
+
(1.0, 0.0, 1.0),
|
25 |
+
]
|
26 |
+
|
27 |
+
mask_list = []
|
28 |
+
mask_raw_list = []
|
29 |
+
mask_list_video = []
|
30 |
+
mask_raw_list_video = []
|
31 |
+
|
32 |
+
def extract_first_frame_from_video(video):
|
33 |
+
cap = cv2.VideoCapture(video)
|
34 |
+
success, frame = cap.read()
|
35 |
+
cap.release()
|
36 |
+
if success:
|
37 |
+
return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
38 |
+
return None
|
39 |
+
|
40 |
+
def extract_points_from_mask(mask_pil):
|
41 |
+
mask = np.asarray(mask_pil)[..., 0]
|
42 |
+
coords = np.nonzero(mask)
|
43 |
+
coords = np.stack((coords[1], coords[0]), axis=1)
|
44 |
+
|
45 |
+
return coords
|
46 |
+
|
47 |
+
def add_contour(img, mask, color=(1., 1., 1.)):
|
48 |
+
img = img.copy()
|
49 |
+
|
50 |
+
mask = mask.astype(np.uint8) * 255
|
51 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
52 |
+
cv2.drawContours(img, contours, -1, color, thickness=8)
|
53 |
+
|
54 |
+
return img
|
55 |
+
|
56 |
+
def generate_masks(image):
|
57 |
+
global mask_list
|
58 |
+
global mask_raw_list
|
59 |
+
image['image'] = image['background'].convert('RGB')
|
60 |
+
# del image['background'], image['composite']
|
61 |
+
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
62 |
+
|
63 |
+
mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
|
64 |
+
points = extract_points_from_mask(mask)
|
65 |
+
np.random.seed(0)
|
66 |
+
if points.shape[0] == 0:
|
67 |
+
raise gr.Error("No points selected")
|
68 |
+
|
69 |
+
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
|
70 |
+
points = points[points_selected_indices]
|
71 |
+
coords = [points.tolist()]
|
72 |
+
mask_np = apply_sam(image['image'], coords)
|
73 |
+
|
74 |
+
mask_raw_list.append(mask_np)
|
75 |
+
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
|
76 |
+
|
77 |
+
mask_list.append((mask_image, f"<region{len(mask_list)}>"))
|
78 |
+
# Return a list containing the mask image.
|
79 |
+
image['layers'] = []
|
80 |
+
image['composite'] = image['background']
|
81 |
+
return mask_list, image
|
82 |
+
|
83 |
+
|
84 |
+
def generate_masks_video(image):
|
85 |
+
global mask_list_video
|
86 |
+
global mask_raw_list_video
|
87 |
+
image['image'] = image['background'].convert('RGB')
|
88 |
+
# del image['background'], image['composite']
|
89 |
+
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
90 |
+
|
91 |
+
mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
|
92 |
+
points = extract_points_from_mask(mask)
|
93 |
+
np.random.seed(0)
|
94 |
+
if points.shape[0] == 0:
|
95 |
+
raise gr.Error("No points selected")
|
96 |
+
|
97 |
+
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
|
98 |
+
points = points[points_selected_indices]
|
99 |
+
coords = [points.tolist()]
|
100 |
+
mask_np = apply_sam(image['image'], coords)
|
101 |
+
|
102 |
+
mask_raw_list_video.append(mask_np)
|
103 |
+
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(image['image'])).astype(np.uint8))
|
104 |
+
|
105 |
+
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
|
106 |
+
# Return a list containing the mask image.
|
107 |
+
image['layers'] = []
|
108 |
+
image['composite'] = image['background']
|
109 |
+
return mask_list_video, image
|
110 |
+
|
111 |
+
|
112 |
+
|
113 |
+
def describe(image, mode, query, masks):
|
114 |
+
# Create an image object from the uploaded image
|
115 |
+
# print(image.keys())
|
116 |
+
|
117 |
+
image['image'] = image['background'].convert('RGB')
|
118 |
+
# del image['background'], image['composite']
|
119 |
+
assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"
|
120 |
+
|
121 |
+
# Handle both hex and rgba color formats
|
122 |
+
|
123 |
+
img_np = np.asarray(image['image']).astype(float) / 255.
|
124 |
+
if mode=='Caption':
|
125 |
+
mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
|
126 |
+
|
127 |
+
points = extract_points_from_mask(mask)
|
128 |
+
|
129 |
+
np.random.seed(0)
|
130 |
+
|
131 |
+
if points.shape[0] == 0:
|
132 |
+
if len(masks)>1:
|
133 |
+
raise gr.Error("No points selected")
|
134 |
+
|
135 |
+
else:
|
136 |
+
# Randomly sample 8 points from the mask
|
137 |
+
# Follow DAM https://github.com/NVlabs/describe-anything
|
138 |
+
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
|
139 |
+
points = points[points_selected_indices]
|
140 |
+
|
141 |
+
coords = [points.tolist()]
|
142 |
+
|
143 |
+
mask_np = apply_sam(image['image'], coords)
|
144 |
+
|
145 |
+
masks = []
|
146 |
+
masks.append(mask_np)
|
147 |
+
mask_ids = [0]
|
148 |
+
|
149 |
+
img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
|
150 |
+
img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
151 |
+
else:
|
152 |
+
masks = mask_raw_list
|
153 |
+
img_with_contour_np = img_np.copy()
|
154 |
+
|
155 |
+
mask_ids = []
|
156 |
+
for i, mask_np in enumerate(masks):
|
157 |
+
img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
|
158 |
+
img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
159 |
+
mask_ids.append(0)
|
160 |
+
|
161 |
+
masks = np.stack(masks, axis=0)
|
162 |
+
masks = torch.from_numpy(masks).to(torch.uint8)
|
163 |
+
|
164 |
+
|
165 |
+
|
166 |
+
img = np.asarray(image['image'])
|
167 |
+
|
168 |
+
|
169 |
+
if mode == "Caption":
|
170 |
+
query = '<image>\nPlease describe the <region> in the image in detail.'
|
171 |
+
else:
|
172 |
+
if len(masks)==1:
|
173 |
+
prefix = "<image>\nThere is 1 region in the image: <region0> <region>. "
|
174 |
+
else:
|
175 |
+
prefix = f"<image>\nThere is {len(masks)} region in the image: "
|
176 |
+
for i in range(len(masks)):
|
177 |
+
prefix += f"<region{i}><region>, "
|
178 |
+
prefix = prefix[:-2]+'. '
|
179 |
+
query = prefix + query
|
180 |
+
# print(query)
|
181 |
+
|
182 |
+
image['layers'] = []
|
183 |
+
image['composite'] = image['background']
|
184 |
+
|
185 |
+
text = ""
|
186 |
+
yield img_with_contour_pil, text, image
|
187 |
+
|
188 |
+
for token in get_model_output(
|
189 |
+
[img],
|
190 |
+
query,
|
191 |
+
model=model,
|
192 |
+
tokenizer=tokenizer,
|
193 |
+
masks=masks,
|
194 |
+
mask_ids=mask_ids,
|
195 |
+
modal='image',
|
196 |
+
image_downsampling=1,
|
197 |
+
streaming=True,
|
198 |
+
):
|
199 |
+
text += token
|
200 |
+
yield gr.update(), text, gr.update()
|
201 |
+
|
202 |
+
|
203 |
+
def load_first_frame(video_path):
|
204 |
+
cap = cv2.VideoCapture(video_path)
|
205 |
+
ret, frame = cap.read()
|
206 |
+
cap.release()
|
207 |
+
if not ret:
|
208 |
+
raise gr.Error("Could not read the video file.")
|
209 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
210 |
+
image = Image.fromarray(frame)
|
211 |
+
return image
|
212 |
+
|
213 |
+
def describe_video(video_path, mode, query, annotated_frame, masks):
|
214 |
+
global mask_list_video
|
215 |
+
# Create a temporary directory to save extracted video frames
|
216 |
+
cap = cv2.VideoCapture(video_path)
|
217 |
+
|
218 |
+
video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])
|
219 |
+
|
220 |
+
annotated_frame['image'] = annotated_frame['background'].convert('RGB')
|
221 |
+
|
222 |
+
# Process the annotated frame from the image editor
|
223 |
+
if isinstance(annotated_frame, dict):
|
224 |
+
# Get the composite image with annotations
|
225 |
+
frame_img = annotated_frame.get("image", annotated_frame.get("background"))
|
226 |
+
if frame_img is None:
|
227 |
+
raise gr.Error("No valid annotation found in the image editor.")
|
228 |
+
frame_img = frame_img.convert("RGB")
|
229 |
+
|
230 |
+
# Get the annotation layer
|
231 |
+
if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
|
232 |
+
mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
|
233 |
+
else:
|
234 |
+
mask = Image.new("RGB", frame_img.size, 0)
|
235 |
+
else:
|
236 |
+
frame_img = annotated_frame.convert("RGB")
|
237 |
+
mask = Image.new("RGB", frame_img.size, 0)
|
238 |
+
|
239 |
+
img_np = np.asarray(annotated_frame['image']).astype(float) / 255.
|
240 |
+
# Extract points from the annotated mask (using the first channel)
|
241 |
+
if mode == "Caption":
|
242 |
+
points = extract_points_from_mask(mask)
|
243 |
+
np.random.seed(0)
|
244 |
+
if points.shape[0] == 0:
|
245 |
+
raise gr.Error("No points were selected in the annotation.")
|
246 |
+
# Randomly select up to 8 points
|
247 |
+
# Follow DAM https://github.com/NVlabs/describe-anything
|
248 |
+
points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
|
249 |
+
points = points[points_selected_indices]
|
250 |
+
|
251 |
+
# print(f"Selected points (to SAM): {points}")
|
252 |
+
|
253 |
+
coords = [points.tolist()]
|
254 |
+
|
255 |
+
mask_np = apply_sam(annotated_frame['image'], coords)
|
256 |
+
|
257 |
+
masks = []
|
258 |
+
masks.append(mask_np)
|
259 |
+
mask_ids = [0]
|
260 |
+
|
261 |
+
# img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
|
262 |
+
# img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
263 |
+
|
264 |
+
|
265 |
+
else:
|
266 |
+
masks = mask_raw_list_video
|
267 |
+
img_with_contour_np = img_np.copy()
|
268 |
+
|
269 |
+
mask_ids = []
|
270 |
+
for i, mask_np in enumerate(masks):
|
271 |
+
# img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
|
272 |
+
# img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
|
273 |
+
mask_ids.append(0)
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
masks = np.stack(masks, axis=0)
|
278 |
+
masks = torch.from_numpy(masks).to(torch.uint8)
|
279 |
+
|
280 |
+
|
281 |
+
|
282 |
+
|
283 |
+
if mode == "Caption":
|
284 |
+
query = '<video>\nPlease describe the <region> in the video in detail.'
|
285 |
+
else:
|
286 |
+
if len(masks)==1:
|
287 |
+
prefix = "<video>\nThere is 1 object in the video: <object0> <region>. "
|
288 |
+
else:
|
289 |
+
prefix = f"<video>\nThere is {len(masks)} objects in the video: "
|
290 |
+
for i in range(len(masks)):
|
291 |
+
prefix += f"<object{i}><region>, "
|
292 |
+
prefix = prefix[:-2]+'. '
|
293 |
+
query = prefix + query
|
294 |
+
|
295 |
+
# Initialize empty text
|
296 |
+
# text = description_generator
|
297 |
+
annotated_frame['layers'] = []
|
298 |
+
annotated_frame['composite'] = annotated_frame['background']
|
299 |
+
|
300 |
+
if mode=="Caption":
|
301 |
+
mask_list_video = []
|
302 |
+
mask_image = Image.fromarray((mask_np[:,:,np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
|
303 |
+
mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
|
304 |
+
text = ""
|
305 |
+
yield frame_img, text, mask_list_video
|
306 |
+
|
307 |
+
for token in get_model_output(
|
308 |
+
video_tensor,
|
309 |
+
query,
|
310 |
+
model=model,
|
311 |
+
tokenizer=tokenizer,
|
312 |
+
masks=masks,
|
313 |
+
mask_ids=mask_ids,
|
314 |
+
modal='video',
|
315 |
+
streaming=True,
|
316 |
+
):
|
317 |
+
text += token
|
318 |
+
yield gr.update(), text, gr.update()
|
319 |
+
|
320 |
+
|
321 |
+
|
322 |
+
def apply_sam(image, input_points):
|
323 |
+
inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
|
324 |
+
|
325 |
+
with torch.no_grad():
|
326 |
+
outputs = sam_model(**inputs)
|
327 |
+
|
328 |
+
masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
|
329 |
+
scores = outputs.iou_scores[0, 0]
|
330 |
+
|
331 |
+
mask_selection_index = scores.argmax()
|
332 |
+
|
333 |
+
mask_np = masks[mask_selection_index].numpy()
|
334 |
+
|
335 |
+
return mask_np
|
336 |
+
|
337 |
+
def clear_masks():
|
338 |
+
global mask_list
|
339 |
+
global mask_raw_list
|
340 |
+
mask_list = []
|
341 |
+
mask_raw_list = []
|
342 |
+
return []
|
343 |
+
|
344 |
+
|
345 |
+
def clear_masks_video():
|
346 |
+
global mask_list_video
|
347 |
+
global mask_raw_list_video
|
348 |
+
mask_list_video = []
|
349 |
+
mask_raw_list_video = []
|
350 |
+
return []
|
351 |
+
|
352 |
+
|
353 |
+
if __name__ == "__main__":
|
354 |
+
parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
|
355 |
+
parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
|
356 |
+
parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
|
357 |
+
parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
|
358 |
+
parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
|
359 |
+
parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")
|
360 |
+
|
361 |
+
args_cli = parser.parse_args()
|
362 |
+
print(args_cli.model_path)
|
363 |
+
|
364 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:
|
365 |
+
|
366 |
+
HEADER = ("""
|
367 |
+
<div>
|
368 |
+
<h1>VideoRefer X VideoLLaMA3 Demo</h1>
|
369 |
+
<h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
|
370 |
+
<h5 style="margin: 0;">If this demo please you, please give us a star ⭐ on Github or 💖 on this space.</h5>
|
371 |
+
</div>
|
372 |
+
</div>
|
373 |
+
<div style="display: flex; justify-content: left; margin-top: 10px;">
|
374 |
+
<a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
|
375 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
|
376 |
+
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
|
377 |
+
</div>
|
378 |
+
""")
|
379 |
+
|
380 |
+
with gr.Row():
|
381 |
+
with gr.Column():
|
382 |
+
gr.HTML(HEADER)
|
383 |
+
|
384 |
+
|
385 |
+
image_tips = """
|
386 |
+
### 💡 Tips:
|
387 |
+
|
388 |
+
🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.
|
389 |
+
|
390 |
+
🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
|
391 |
+
|
392 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
393 |
+
|
394 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
395 |
+
|
396 |
+
"""
|
397 |
+
|
398 |
+
video_tips = """
|
399 |
+
### 💡 Tips:
|
400 |
+
⚠️ For video mode, we only support masking on the first frame in this demo.
|
401 |
+
|
402 |
+
🧸 Upload an video, and you can use the drawing tool✍️ to highlight the areas you're interested in the first frame.
|
403 |
+
|
404 |
+
🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.
|
405 |
+
|
406 |
+
🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.
|
407 |
+
|
408 |
+
📌 Click the button 'Clear Masks' to clear the current generated masks.
|
409 |
+
|
410 |
+
"""
|
411 |
+
|
412 |
+
|
413 |
+
with gr.TabItem("Image"):
|
414 |
+
with gr.Row():
|
415 |
+
with gr.Column():
|
416 |
+
image_input = gr.ImageEditor(
|
417 |
+
label="Image",
|
418 |
+
type="pil",
|
419 |
+
sources=['upload'],
|
420 |
+
brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
|
421 |
+
eraser=True,
|
422 |
+
layers=False,
|
423 |
+
transforms=[],
|
424 |
+
height=300,
|
425 |
+
)
|
426 |
+
generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
|
427 |
+
mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
|
428 |
+
query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)
|
429 |
+
|
430 |
+
submit_btn = gr.Button("Generate Caption", variant="primary")
|
431 |
+
submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
|
432 |
+
gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")
|
433 |
+
|
434 |
+
with gr.Column():
|
435 |
+
mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
|
436 |
+
output_image = gr.Image(label="Image with Mask", visible=True, height=400)
|
437 |
+
description = gr.Textbox(label="Output", visible=True)
|
438 |
+
|
439 |
+
clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
|
440 |
+
gr.Markdown(image_tips)
|
441 |
+
|
442 |
+
with gr.TabItem("Video"):
|
443 |
+
with gr.Row():
|
444 |
+
with gr.Column():
|
445 |
+
video_input = gr.Video(label="Video")
|
446 |
+
# load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
|
447 |
+
first_frame = gr.ImageEditor(
|
448 |
+
label="Annotate First Frame",
|
449 |
+
type="pil",
|
450 |
+
sources=['upload'],
|
451 |
+
brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
|
452 |
+
eraser=True,
|
453 |
+
layers=False,
|
454 |
+
transforms=[],
|
455 |
+
height=300,
|
456 |
+
)
|
457 |
+
generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
|
458 |
+
gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
|
459 |
+
|
460 |
+
with gr.Column():
|
461 |
+
mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
|
462 |
+
mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
|
463 |
+
|
464 |
+
query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)
|
465 |
+
|
466 |
+
submit_btn_video = gr.Button("Generate Caption", variant="primary")
|
467 |
+
submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
|
468 |
+
description_video = gr.Textbox(label="Output", visible=True)
|
469 |
+
|
470 |
+
clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
|
471 |
+
|
472 |
+
gr.Markdown(video_tips)
|
473 |
+
|
474 |
+
|
475 |
+
def toggle_query_and_generate_button(mode):
|
476 |
+
query_visible = mode == "QA"
|
477 |
+
caption_visible = mode == "Caption"
|
478 |
+
global mask_list
|
479 |
+
global mask_raw_list
|
480 |
+
mask_list = []
|
481 |
+
mask_raw_list = []
|
482 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], ""
|
483 |
+
|
484 |
+
video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)
|
485 |
+
|
486 |
+
mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description])
|
487 |
+
|
488 |
+
def toggle_query_and_generate_button_video(mode):
|
489 |
+
query_visible = mode == "QA"
|
490 |
+
caption_visible = mode == "Caption"
|
491 |
+
global mask_list_video
|
492 |
+
global mask_raw_list_video
|
493 |
+
mask_list_video = []
|
494 |
+
mask_raw_list_video = []
|
495 |
+
return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), []
|
496 |
+
|
497 |
+
|
498 |
+
mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video])
|
499 |
+
|
500 |
+
submit_btn.click(
|
501 |
+
fn=describe,
|
502 |
+
inputs=[image_input, mode, query],
|
503 |
+
outputs=[output_image, description, image_input],
|
504 |
+
api_name="describe"
|
505 |
+
)
|
506 |
+
|
507 |
+
submit_btn1.click(
|
508 |
+
fn=describe,
|
509 |
+
inputs=[image_input, mode, query],
|
510 |
+
outputs=[output_image, description, image_input],
|
511 |
+
api_name="describe"
|
512 |
+
)
|
513 |
+
|
514 |
+
generate_mask_btn.click(
|
515 |
+
fn=generate_masks,
|
516 |
+
inputs=[image_input],
|
517 |
+
outputs=[mask_output, image_input]
|
518 |
+
)
|
519 |
+
|
520 |
+
generate_mask_btn_video.click(
|
521 |
+
fn=generate_masks_video,
|
522 |
+
inputs=[first_frame],
|
523 |
+
outputs=[mask_output_video, first_frame]
|
524 |
+
)
|
525 |
+
|
526 |
+
clear_masks_btn.click(
|
527 |
+
fn=clear_masks,
|
528 |
+
outputs=[mask_output]
|
529 |
+
)
|
530 |
+
|
531 |
+
clear_masks_btn_video.click(
|
532 |
+
fn=clear_masks_video,
|
533 |
+
outputs=[mask_output_video]
|
534 |
+
)
|
535 |
+
|
536 |
+
submit_btn_video.click(
|
537 |
+
fn=describe_video,
|
538 |
+
inputs=[video_input, mode_video, query_video, first_frame],
|
539 |
+
outputs=[first_frame, description_video, mask_output_video],
|
540 |
+
api_name="describe_video"
|
541 |
+
)
|
542 |
+
|
543 |
+
submit_btn_video1.click(
|
544 |
+
fn=describe_video,
|
545 |
+
inputs=[video_input, mode_video, query_video, first_frame],
|
546 |
+
outputs=[first_frame, description_video, mask_output_video],
|
547 |
+
api_name="describe_video"
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
|
552 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
553 |
+
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
|
554 |
+
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
555 |
+
|
556 |
+
disable_torch_init()
|
557 |
+
|
558 |
+
|
559 |
+
model, processor, tokenizer = model_init(args_cli.model_path)
|
560 |
+
|
561 |
+
|
562 |
+
demo.launch()
|
demo/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
demo/images/1.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/2.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/3.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/4.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/5.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/6.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/7.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/8.jpg
ADDED
![]() |
Git LFS Details
|
demo/images/LICENSE
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ac4c813c90895cdc79c71fdbd02715fd0c5505c24d95c5941747c904d6e93bc
|
3 |
+
size 149
|
demo/videos/1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ad78d268f6f1ad9a457a7768665157f74c20292136cefbf6bfc2a07de940dd0a
|
3 |
+
size 804232
|
demo/videos/2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5eebbd330be490709c1b39cd1d82ae074f3fe275487bc6b77d2aa5cd74d40d05
|
3 |
+
size 1255466
|
demo/videos/3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:946550c741c9dc515340ab93b203614094632191db0d8f9697bd580f4a271947
|
3 |
+
size 8743247
|
demo/videos/4.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b06b309812947b909ce7b8eaaea94a9ca60a8452a33e3109f5f6ffb1dbf8ee6
|
3 |
+
size 1334796
|
requirements.txt
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu124
|
2 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
3 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
4 |
+
# basic dependencies
|
5 |
+
torch==2.4.0
|
6 |
+
torchvision==0.19.0
|
7 |
+
datasets==2.21.0
|
8 |
+
transformers==4.46.3
|
9 |
+
tokenizers==0.20.3
|
10 |
+
deepspeed==0.15.4
|
11 |
+
accelerate==1.0.1
|
12 |
+
peft==0.4.0
|
13 |
+
timm==1.0.3
|
14 |
+
numpy==1.24.4
|
15 |
+
# data processing
|
16 |
+
decord==0.6.0
|
17 |
+
imageio==2.34.0
|
18 |
+
imageio-ffmpeg==0.4.9
|
19 |
+
moviepy==1.0.3
|
20 |
+
scenedetect==0.6.3
|
21 |
+
opencv-python==4.6.0.66
|
22 |
+
pyarrow
|
23 |
+
pysubs2
|
24 |
+
ffmpeg-python
|
25 |
+
# misc
|
26 |
+
scikit-learn==1.2.2
|
27 |
+
huggingface_hub==0.23.4
|
28 |
+
sentencepiece==0.1.99
|
29 |
+
shortuuid
|
30 |
+
einops==0.6.1
|
31 |
+
einops-exts==0.0.4
|
32 |
+
bitsandbytes==0.43.3 # for cuda 124
|
33 |
+
pydantic>=2.0
|
34 |
+
markdown2[all]
|
35 |
+
gradio==5.34.0
|
36 |
+
gradio_client==1.10.3
|
37 |
+
httpx==0.24.1
|
38 |
+
requests
|
39 |
+
openai
|
40 |
+
uvicorn
|
41 |
+
fastapi
|
42 |
+
tensorboard
|
43 |
+
wandb
|
44 |
+
tabulate
|
45 |
+
Levenshtein
|
46 |
+
pycocotools==2.0.8
|
47 |
+
spaces
|
48 |
+
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
videollama3/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
videollama3/__init__.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import math
|
4 |
+
import warnings
|
5 |
+
import shutil
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from .model import load_pretrained_model
|
11 |
+
from .model.processor import Videollama3Processor
|
12 |
+
from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, resize_image_mask
|
13 |
+
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
|
14 |
+
from videollama3.constants import REGION_TOKEN
|
15 |
+
from transformers import TextIteratorStreamer
|
16 |
+
from threading import Thread
|
17 |
+
|
18 |
+
def disable_torch_init():
|
19 |
+
"""
|
20 |
+
Disable the redundant torch default initialization to accelerate model creation.
|
21 |
+
"""
|
22 |
+
import torch
|
23 |
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
24 |
+
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
25 |
+
|
26 |
+
|
27 |
+
def model_init(model_path=None, **kwargs):
|
28 |
+
model_path = "DAMO-NLP-SG/VideoLLaMA2-7B" if model_path is None else model_path
|
29 |
+
model_name = get_model_name_from_path(model_path)
|
30 |
+
tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)
|
31 |
+
|
32 |
+
if tokenizer.pad_token is None and tokenizer.unk_token is not None:
|
33 |
+
tokenizer.pad_token = tokenizer.unk_token
|
34 |
+
|
35 |
+
aspect_ratio = model.config.image_aspect_ratio if hasattr(model.config, "image_aspect_ratio") else "pad"
|
36 |
+
image_size = model.config.image_size if hasattr(model.config, "image_size") else 384
|
37 |
+
# NOTE: If num_frames is None, the frame sampling mode is "fps". If num_frames is not None, the frame sampling mode is "uniform".
|
38 |
+
# num_frames = model.config.num_frames
|
39 |
+
model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)
|
40 |
+
processor = {
|
41 |
+
'image': load_images,
|
42 |
+
'video': load_video,
|
43 |
+
'text': None
|
44 |
+
}
|
45 |
+
|
46 |
+
return model, processor, tokenizer
|
47 |
+
|
48 |
+
|
49 |
+
def get_model_output(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
|
50 |
+
streaming = kwargs.pop('streaming', False)
|
51 |
+
if streaming:
|
52 |
+
return mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=True, **kwargs)
|
53 |
+
else:
|
54 |
+
output = mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=False, **kwargs)
|
55 |
+
return next(output)
|
56 |
+
|
57 |
+
|
58 |
+
def mm_infer(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
|
59 |
+
"""inference api of VideoLLaMA2 for video understanding.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
model: VideoLLaMA2 model.
|
63 |
+
images_or_videos (torch.Tensor): image tensor (1, C, H, W) / video tensor (T, C, H, W).
|
64 |
+
instruct (str): text instruction for understanding video.
|
65 |
+
tokenizer: tokenizer.
|
66 |
+
do_sample (bool): whether to sample.
|
67 |
+
modal (str): inference modality.
|
68 |
+
Returns:
|
69 |
+
str: response of the model.
|
70 |
+
"""
|
71 |
+
mask_ids = kwargs.pop('mask_ids', None)
|
72 |
+
masks = kwargs.pop('masks', None)
|
73 |
+
streaming = kwargs.pop('streaming', False)
|
74 |
+
if modal == 'image':
|
75 |
+
modal_token = DEFAULT_IMAGE_TOKEN
|
76 |
+
images = images_or_videos
|
77 |
+
additional_frames = images.copy()
|
78 |
+
timestamps = None
|
79 |
+
elif modal == 'video':
|
80 |
+
modal_token = DEFAULT_VIDEO_TOKEN
|
81 |
+
images, timestamps, additional_frames = images_or_videos
|
82 |
+
elif modal == 'text':
|
83 |
+
modal_token = ''
|
84 |
+
else:
|
85 |
+
raise ValueError(f"Unsupported modal: {modal}")
|
86 |
+
|
87 |
+
vlprocessor = Videollama3Processor(model.get_vision_encoder().image_processor, tokenizer)
|
88 |
+
vlprocessor.tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)
|
89 |
+
|
90 |
+
model.config.image_token_index = vlprocessor.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
|
91 |
+
|
92 |
+
if masks is not None:
|
93 |
+
additional_frames, masks, mask_nums = resize_image_mask(additional_frames, masks, mask_ids)
|
94 |
+
|
95 |
+
for idx in range(len(mask_nums)):
|
96 |
+
instruct = instruct.replace('<region>', "["+REGION_TOKEN*mask_nums[idx]+"]", 1)
|
97 |
+
|
98 |
+
|
99 |
+
additional_images_dict = vlprocessor._process_image(additional_frames, image_downsampling=1)
|
100 |
+
additional_images = additional_images_dict['images']
|
101 |
+
# import pdb
|
102 |
+
# pdb.set_trace()
|
103 |
+
|
104 |
+
|
105 |
+
# flatten_patches1 = additional_images[0].reshape(26, 46, 3, -1)
|
106 |
+
# from matplotlib import pyplot as plt
|
107 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
108 |
+
# plt.savefig('16.png')
|
109 |
+
|
110 |
+
additional_images_thws = additional_images_dict['grid_thws']
|
111 |
+
additional_images = (additional_images, additional_images_thws)
|
112 |
+
|
113 |
+
else:
|
114 |
+
additional_images = None
|
115 |
+
|
116 |
+
|
117 |
+
# 1. text preprocess (tag process & generate prompt).
|
118 |
+
if isinstance(instruct, str):
|
119 |
+
messages = [{'role': 'user', 'content': instruct}]
|
120 |
+
elif isinstance(instruct, list):
|
121 |
+
messages = copy.deepcopy(instruct)
|
122 |
+
else:
|
123 |
+
raise ValueError(f"Unsupported type of instruct: {type(instruct)}")
|
124 |
+
|
125 |
+
if all(not modal_token in message["content"] for message in messages):
|
126 |
+
warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
|
127 |
+
messages[0]["content"] = modal_token + messages[0]["content"]
|
128 |
+
|
129 |
+
converted_messages = []
|
130 |
+
for message in messages:
|
131 |
+
chunks = message["content"].split(modal_token)
|
132 |
+
converted_messages.append({
|
133 |
+
"role": "user",
|
134 |
+
"content": []
|
135 |
+
})
|
136 |
+
|
137 |
+
for chunk_idx in range(1, 2 * len(chunks)):
|
138 |
+
if chunk_idx % 2 == 1:
|
139 |
+
chunk = chunks[chunk_idx // 2].strip()
|
140 |
+
converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
|
141 |
+
else:
|
142 |
+
if modal == 'image':
|
143 |
+
converted_messages[-1]["content"].append({"type": "image"})
|
144 |
+
elif modal == 'video':
|
145 |
+
converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
|
146 |
+
|
147 |
+
messages = converted_messages
|
148 |
+
|
149 |
+
# 2. vision preprocess (load & transform image or video).
|
150 |
+
if model.config.model_type in ['videollama3_mistral', 'videollama3_mixtral']:
|
151 |
+
system_message = [
|
152 |
+
{'role': 'system', 'content': (
|
153 |
+
"""<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."""
|
154 |
+
"""\n"""
|
155 |
+
"""If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>""")
|
156 |
+
}
|
157 |
+
]
|
158 |
+
else:
|
159 |
+
system_message = []
|
160 |
+
|
161 |
+
image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
|
162 |
+
# TODO: attention mask?
|
163 |
+
messages = system_message + messages
|
164 |
+
data_dict = vlprocessor(
|
165 |
+
images=images,
|
166 |
+
text=messages,
|
167 |
+
image_downsampling=image_downsampling,
|
168 |
+
return_tensors="pt",
|
169 |
+
)
|
170 |
+
|
171 |
+
torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16
|
172 |
+
|
173 |
+
images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
|
174 |
+
grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]
|
175 |
+
|
176 |
+
# 3. generate response according to visual signals and prompts.
|
177 |
+
keywords = [tokenizer.eos_token]
|
178 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"])
|
179 |
+
stop_str = tokenizer.eos_token
|
180 |
+
|
181 |
+
do_sample = kwargs.get('do_sample', False)
|
182 |
+
temperature = kwargs.get('temperature', 0.2 if do_sample else 0.0)
|
183 |
+
top_p = kwargs.get('top_p', 0.9)
|
184 |
+
max_new_tokens = kwargs.get('max_new_tokens', 2048)
|
185 |
+
if not streaming:
|
186 |
+
with torch.inference_mode():
|
187 |
+
output_ids = model.generate(
|
188 |
+
# input_ids,
|
189 |
+
# attention_mask=attention_masks,
|
190 |
+
# images=images,
|
191 |
+
data_dict["input_ids"].cuda(),
|
192 |
+
attention_mask=data_dict["attention_mask"].cuda(),
|
193 |
+
images=[(modal, images, grid_thws)],
|
194 |
+
do_sample=do_sample,
|
195 |
+
temperature=temperature,
|
196 |
+
max_new_tokens=max_new_tokens,
|
197 |
+
top_p=top_p,
|
198 |
+
use_cache=True,
|
199 |
+
stopping_criteria=[stopping_criteria],
|
200 |
+
pad_token_id=tokenizer.eos_token_id,
|
201 |
+
additional_images=[additional_images],
|
202 |
+
masks=[masks],
|
203 |
+
)
|
204 |
+
|
205 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
206 |
+
|
207 |
+
yield outputs
|
208 |
+
|
209 |
+
else:
|
210 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
211 |
+
generation_kwargs = dict(
|
212 |
+
inputs=data_dict["input_ids"].cuda(),
|
213 |
+
attention_mask=data_dict["attention_mask"].cuda(),
|
214 |
+
images=[(modal, images, grid_thws)],
|
215 |
+
do_sample=do_sample,
|
216 |
+
temperature=temperature,
|
217 |
+
max_new_tokens=max_new_tokens,
|
218 |
+
top_p=top_p,
|
219 |
+
use_cache=True,
|
220 |
+
stopping_criteria=[stopping_criteria],
|
221 |
+
pad_token_id=tokenizer.eos_token_id,
|
222 |
+
additional_images=[additional_images],
|
223 |
+
masks=[masks],
|
224 |
+
streamer=streamer
|
225 |
+
)
|
226 |
+
|
227 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
228 |
+
thread.start()
|
229 |
+
|
230 |
+
generated_text = ""
|
231 |
+
for new_text in streamer:
|
232 |
+
generated_text += new_text
|
233 |
+
if stop_str in generated_text:
|
234 |
+
generated_text = generated_text[:generated_text.find(stop_str)]
|
235 |
+
break
|
236 |
+
yield new_text
|
237 |
+
|
238 |
+
thread.join()
|
239 |
+
|
videollama3/constants.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
2 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
3 |
+
|
4 |
+
LOGDIR = "."
|
5 |
+
|
6 |
+
# Model Constants
|
7 |
+
IGNORE_INDEX = -100
|
8 |
+
|
9 |
+
# Image arguments
|
10 |
+
IMAGE_TOKEN_INDEX = -200
|
11 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
12 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
13 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
14 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
15 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
16 |
+
|
17 |
+
# Video arguments
|
18 |
+
VIDEO_TOKEN_INDEX = -201
|
19 |
+
DEFAULT_VIDEO_TOKEN = "<video>"
|
20 |
+
NUM_FRAMES = 128
|
21 |
+
MAX_FRAMES = 768
|
22 |
+
NUM_FRAMES_PER_SECOND = 1
|
23 |
+
|
24 |
+
# Region arguments
|
25 |
+
REGION_TOKEN = "<REGION>"
|
26 |
+
|
27 |
+
# Audio arguments
|
28 |
+
AUDIO_TOKEN_INDEX = -202
|
29 |
+
DEFAULT_AUDIO_TOKEN = "<audio>"
|
30 |
+
|
31 |
+
# Stream arguments
|
32 |
+
STREAM_START_TOKEN = "<|stream_start|>"
|
33 |
+
STREAM_END_TOKEN = "<|stream_end|>"
|
34 |
+
STREAM_IMAGE_TOKEN = "<stream_image>"
|
35 |
+
STREAM_FPS = 2
|
36 |
+
STREAM_IMAGE_SIZE = 224
|
37 |
+
STREAM_DOWNSAMPLING = 4
|
38 |
+
STREAM_MAX_FRAMES = 400
|
39 |
+
|
40 |
+
MODAL_INDEX_MAP = {
|
41 |
+
"<image>": -200,
|
42 |
+
"<video>": -201,
|
43 |
+
"<audio>": -202,
|
44 |
+
}
|
45 |
+
|
46 |
+
subimage_token_num=196
|
videollama3/infer.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
|
3 |
+
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import sys
|
7 |
+
sys.path.append('./')
|
8 |
+
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
|
9 |
+
from videollama3.mm_utils import load_video
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
def infer_image(model, tokenizer):
|
15 |
+
image_path = 'demo/images/1.jpg'
|
16 |
+
image = Image.open(image_path)
|
17 |
+
image_data = np.array(image)
|
18 |
+
|
19 |
+
question = '<image>\nPlease describe the <region> in the image in detail.'
|
20 |
+
|
21 |
+
mask = np.load('demo/masks/demo0.npy')
|
22 |
+
masks = []
|
23 |
+
masks.append(mask)
|
24 |
+
masks = np.array(masks)
|
25 |
+
masks = torch.from_numpy(masks).to(torch.uint8)
|
26 |
+
|
27 |
+
mask_ids = [0]*len(masks)
|
28 |
+
|
29 |
+
output = get_model_output(
|
30 |
+
[image_data],
|
31 |
+
question,
|
32 |
+
model=model,
|
33 |
+
tokenizer=tokenizer,
|
34 |
+
masks=masks,
|
35 |
+
mask_ids=mask_ids,
|
36 |
+
modal='image',
|
37 |
+
image_downsampling=1,
|
38 |
+
)
|
39 |
+
print(output)
|
40 |
+
|
41 |
+
def infer_video(model, tokenizer):
|
42 |
+
video_path = 'demo/videos/1.mp4'
|
43 |
+
question = '<video>\nPlease describe the <region> in the video in detail.'
|
44 |
+
|
45 |
+
frame_idx = 0 # mask from the first frame
|
46 |
+
video_tensor = load_video(video_path, fps=1, max_frames=768, frame_ids=[frame_idx])
|
47 |
+
|
48 |
+
mask = np.load('demo/masks/demo1.npy')
|
49 |
+
masks = []
|
50 |
+
masks.append(mask)
|
51 |
+
masks = np.array(masks)
|
52 |
+
masks = torch.from_numpy(masks).to(torch.uint8)
|
53 |
+
|
54 |
+
mask_ids = [0]*len(masks)
|
55 |
+
|
56 |
+
output = get_model_output(
|
57 |
+
video_tensor,
|
58 |
+
question,
|
59 |
+
model=model,
|
60 |
+
tokenizer=tokenizer,
|
61 |
+
masks=masks,
|
62 |
+
mask_ids=mask_ids,
|
63 |
+
modal='video',
|
64 |
+
)
|
65 |
+
print(output)
|
66 |
+
|
67 |
+
def main():
|
68 |
+
disable_torch_init()
|
69 |
+
|
70 |
+
# fill in the model path here
|
71 |
+
model_path = '/mnt/workspace/workgroup/yuanyq/code/videollama3/ProjectX_region/work_dirs/VideoRefer-VideoLLaMA3-7B'
|
72 |
+
model, processor, tokenizer = model_init(model_path)
|
73 |
+
|
74 |
+
# image
|
75 |
+
infer_image(model, tokenizer)
|
76 |
+
|
77 |
+
# viideo
|
78 |
+
infer_video(model, tokenizer)
|
79 |
+
|
80 |
+
|
81 |
+
if __name__=='__main__':
|
82 |
+
main()
|
videollama3/mm_utils.py
ADDED
@@ -0,0 +1,704 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import math
|
5 |
+
import base64
|
6 |
+
import traceback
|
7 |
+
from io import BytesIO
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms.functional as VF
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import numpy as np
|
14 |
+
from transformers import StoppingCriteria
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import imageio
|
18 |
+
import ffmpeg
|
19 |
+
from PIL import Image
|
20 |
+
from decord import VideoReader, cpu
|
21 |
+
|
22 |
+
from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
|
23 |
+
from pycocotools import mask as maskUtils
|
24 |
+
|
25 |
+
def resize_image_mask(images, masks, mask_ids, patch_size=14):
|
26 |
+
resize_images = []
|
27 |
+
resize_masks = []
|
28 |
+
mask_nums = []
|
29 |
+
for i, mask in enumerate(masks):
|
30 |
+
image = images[mask_ids[i]]
|
31 |
+
h, w = image.shape[:2]
|
32 |
+
if mask.sum()==0:
|
33 |
+
print('mask is none...')
|
34 |
+
mask = torch.ones((h, w))
|
35 |
+
rows, cols = np.where(mask == 1)
|
36 |
+
|
37 |
+
min_row, max_row = rows.min(), rows.max()
|
38 |
+
min_col, max_col = cols.min(), cols.max()
|
39 |
+
|
40 |
+
bbox = (max(0,min_row-patch_size*2), max(0,min_col-patch_size*2), min(h-1, max_row+patch_size*2), min(w-1, max_col+patch_size*2))
|
41 |
+
mask_h = bbox[2] - bbox[0]
|
42 |
+
mask_w = bbox[3] - bbox[1]
|
43 |
+
cropping_img = image[bbox[0]: bbox[2], bbox[1]: bbox[3], :]
|
44 |
+
cropping_mask = mask[bbox[0]: bbox[2], bbox[1]: bbox[3]]
|
45 |
+
|
46 |
+
scale_rate = math.ceil(math.sqrt(1960/mask.sum()))
|
47 |
+
if scale_rate==1:
|
48 |
+
if (mask.sum()/196)>100:
|
49 |
+
scale_rate = math.sqrt((mask.sum()/196)/100)
|
50 |
+
scale_rate = 1/scale_rate
|
51 |
+
resize_h = math.ceil((mask_h*scale_rate)/patch_size) * patch_size
|
52 |
+
resize_w = math.ceil((mask_w*scale_rate)/patch_size) * patch_size
|
53 |
+
|
54 |
+
resize_img = cv2.resize(cropping_img, (resize_w, resize_h))
|
55 |
+
resize_mask = F.interpolate(cropping_mask[None, None], size=(resize_h//patch_size, resize_w//patch_size), mode='bilinear', align_corners=False)[0,0]
|
56 |
+
mask_nums.append(min(10, int(resize_mask.sum())))
|
57 |
+
|
58 |
+
resize_images.append(resize_img)
|
59 |
+
resize_masks.append(resize_mask)
|
60 |
+
|
61 |
+
return resize_images, resize_masks, mask_nums
|
62 |
+
|
63 |
+
def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
|
64 |
+
start_idx=0
|
65 |
+
reshaped_features = []
|
66 |
+
for thw_group in grid_thws:
|
67 |
+
for tensor_thw in thw_group:
|
68 |
+
_, H, W = tensor_thw.squeeze().tolist()
|
69 |
+
num_elements = H * W
|
70 |
+
|
71 |
+
split_tensor = mm_features_raw[start_idx:start_idx + num_elements].view(H, W, -1)
|
72 |
+
reshaped_features.append(split_tensor)
|
73 |
+
|
74 |
+
start_idx += num_elements
|
75 |
+
assert len(mm_features_raw)==start_idx
|
76 |
+
return reshaped_features
|
77 |
+
|
78 |
+
def annToMask(mask_ann, h=None, w=None):
|
79 |
+
if isinstance(mask_ann, list):
|
80 |
+
rles = maskUtils.frPyObjects(mask_ann, h, w)
|
81 |
+
rle = maskUtils.merge(rles)
|
82 |
+
elif isinstance(mask_ann['counts'], list):
|
83 |
+
# uncompressed RLE
|
84 |
+
rle = maskUtils.frPyObjects(mask_ann, h, w)
|
85 |
+
else:
|
86 |
+
# rle
|
87 |
+
rle = mask_ann
|
88 |
+
mask = maskUtils.decode(rle)
|
89 |
+
return mask
|
90 |
+
|
91 |
+
def chunk_list(input_list, chunk_size):
|
92 |
+
return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
|
93 |
+
|
94 |
+
|
95 |
+
def load_image_from_base64(image):
|
96 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
97 |
+
|
98 |
+
|
99 |
+
def expand2square(pil_img, background_color):
|
100 |
+
width, height = pil_img.size
|
101 |
+
if width == height:
|
102 |
+
return pil_img
|
103 |
+
elif width > height:
|
104 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
105 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
106 |
+
return result
|
107 |
+
else:
|
108 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
109 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
110 |
+
return result
|
111 |
+
|
112 |
+
|
113 |
+
def grid_divide(image, cell_size):
|
114 |
+
"""
|
115 |
+
Divides an image into grid of a specified size.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
image (PIL.Image.Image): The input image.
|
119 |
+
cell_size (int): The size of each cell.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
123 |
+
"""
|
124 |
+
grid = []
|
125 |
+
width, height = image.size
|
126 |
+
for i in range(0, height, cell_size):
|
127 |
+
row = []
|
128 |
+
for j in range(0, width, cell_size):
|
129 |
+
box = (j, i, j + cell_size, i + cell_size)
|
130 |
+
row.append(image.crop(box))
|
131 |
+
grid.append(row)
|
132 |
+
|
133 |
+
return grid
|
134 |
+
|
135 |
+
|
136 |
+
def load_images(image_path):
|
137 |
+
if isinstance(image_path, str) and os.path.isfile(image_path):
|
138 |
+
images = [cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)]
|
139 |
+
# images = [Image.open(image_path).convert('RGB')]
|
140 |
+
elif isinstance(image_path, str) and os.path.isdir(image_path):
|
141 |
+
images = [cv2.cvtColor(cv2.imread(os.path.join(image_path, f)), cv2.COLOR_BGR2RGB) for f in sorted(os.listdir(image_path))]
|
142 |
+
# images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
|
143 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], str):
|
144 |
+
images = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in image_path]
|
145 |
+
# images = [Image.open(f).convert('RGB') for f in image_path]
|
146 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
|
147 |
+
images = image_path
|
148 |
+
elif isinstance(image_path, Image.Image):
|
149 |
+
images = [image_path]
|
150 |
+
else:
|
151 |
+
print('image_path: ', image_path)
|
152 |
+
raise ValueError(f"Unsupported image path type: {image_path}")
|
153 |
+
|
154 |
+
return images
|
155 |
+
|
156 |
+
|
157 |
+
def process_pad_image(image, padding_value=(0, 0, 0)):
|
158 |
+
image = expand2square(image, padding_value)
|
159 |
+
|
160 |
+
return [image]
|
161 |
+
|
162 |
+
|
163 |
+
def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
|
164 |
+
best_ratio_diff = float('inf')
|
165 |
+
best_ratio = (1, 1)
|
166 |
+
area = ori_size[0] * ori_size[1]
|
167 |
+
for ratio in tgt_ratios:
|
168 |
+
tgt_ratio = ratio[0] / ratio[1]
|
169 |
+
ratio_diff = abs(src_ratio - tgt_ratio)
|
170 |
+
if ratio_diff < best_ratio_diff:
|
171 |
+
best_ratio_diff = ratio_diff
|
172 |
+
best_ratio = ratio
|
173 |
+
elif ratio_diff == best_ratio_diff:
|
174 |
+
if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
|
175 |
+
best_ratio = ratio
|
176 |
+
|
177 |
+
return best_ratio
|
178 |
+
|
179 |
+
|
180 |
+
def process_dynamic_image(image, image_size=384, use_thumbnail=True):
|
181 |
+
# Grid Params:
|
182 |
+
min_num = 1
|
183 |
+
max_num = 12
|
184 |
+
|
185 |
+
if isinstance(image_size, int):
|
186 |
+
image_size = (image_size, image_size)
|
187 |
+
|
188 |
+
ori_size = image.size
|
189 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
190 |
+
|
191 |
+
# calculate the existing image aspect ratio
|
192 |
+
tgt_ratios = []
|
193 |
+
for n in range(min_num, max_num + 1):
|
194 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
195 |
+
tgt_ratios = set(tgt_ratios)
|
196 |
+
tgt_ratios = sorted(tgt_ratios, key=lambda x: x[0] * x[1])
|
197 |
+
|
198 |
+
# find the closest aspect ratio to the target
|
199 |
+
tgt_ratio = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
|
200 |
+
|
201 |
+
# resize the image to the target size
|
202 |
+
tgt_width = image_size[0] * tgt_ratio[0]
|
203 |
+
tgt_height = image_size[1] * tgt_ratio[1]
|
204 |
+
resized_img = image.resize((tgt_width, tgt_height))
|
205 |
+
|
206 |
+
# NOTE: internvl2 style split the image into one column grids
|
207 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
208 |
+
# grid_images = []
|
209 |
+
# for i in range(num_grids):
|
210 |
+
# box = (
|
211 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
212 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
213 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
214 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
215 |
+
# )
|
216 |
+
# # crop out the grid image
|
217 |
+
# grid_images.append(resized_img.crop(box))
|
218 |
+
# assert len(grid_images) == num_grids
|
219 |
+
# grid_images = [grid_images]
|
220 |
+
|
221 |
+
# NOTE: eager implementation
|
222 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
223 |
+
# sub_grid_images = []
|
224 |
+
# tmp_grid_images = []
|
225 |
+
# for i in range(num_grids):
|
226 |
+
# box = (
|
227 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
228 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
229 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
230 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
231 |
+
# )
|
232 |
+
# tmp_grid_images.append(resized_img.crop(box))
|
233 |
+
|
234 |
+
# if (i + 1) % tgt_ratio[0] == 0:
|
235 |
+
# sub_grid_images.append(tmp_grid_images)
|
236 |
+
# tmp_grid_images = []
|
237 |
+
|
238 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
239 |
+
|
240 |
+
if use_thumbnail:
|
241 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
242 |
+
image_grid = [[thumbnail_img]] + image_grid
|
243 |
+
|
244 |
+
return image_grid
|
245 |
+
|
246 |
+
|
247 |
+
def process_highres_image(image_path, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
248 |
+
# Grid Params:
|
249 |
+
grid_width = [1, 2, 3]
|
250 |
+
grid_width_real = [x * image_size for x in grid_width]
|
251 |
+
|
252 |
+
longest_side = max(image.size)
|
253 |
+
fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
|
254 |
+
if len(fit_grid_width_real) == 0:
|
255 |
+
select_size = max(grid_width_real)
|
256 |
+
else:
|
257 |
+
select_size = min(fit_grid_width_real)
|
258 |
+
|
259 |
+
image_padded = expand2square(image, padding_value)
|
260 |
+
image_padded = image_padded.resize((select_size, select_size))
|
261 |
+
image_grid = grid_divide(image_padded, image_size)
|
262 |
+
|
263 |
+
if use_thumbnail:
|
264 |
+
thumbnail_img = image.resize((image_size, image_size))
|
265 |
+
image_grid = [[thumbnail_img]] + image_grid
|
266 |
+
|
267 |
+
return image_grid
|
268 |
+
|
269 |
+
|
270 |
+
def select_best_resolution(original_size, possible_resolutions):
|
271 |
+
"""
|
272 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
276 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
tuple: The best fit resolution in the format (width, height).
|
280 |
+
"""
|
281 |
+
original_width, original_height = original_size
|
282 |
+
best_fit = None
|
283 |
+
max_effective_resolution = 0
|
284 |
+
min_wasted_resolution = float('inf')
|
285 |
+
|
286 |
+
for width, height in possible_resolutions:
|
287 |
+
scale = min(width / original_width, height / original_height)
|
288 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
289 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
290 |
+
wasted_resolution = (width * height) - effective_resolution
|
291 |
+
|
292 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
293 |
+
max_effective_resolution = effective_resolution
|
294 |
+
min_wasted_resolution = wasted_resolution
|
295 |
+
best_fit = (width, height)
|
296 |
+
|
297 |
+
return best_fit
|
298 |
+
|
299 |
+
|
300 |
+
def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
301 |
+
"""
|
302 |
+
Process an image with variable resolutions.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
image (PIL.Image.Image): The input image to be processed.
|
306 |
+
processor: The image processor object.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
torch.Tensor: A tensor containing the processed image patches.
|
310 |
+
"""
|
311 |
+
# Grid Params:
|
312 |
+
possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
|
313 |
+
possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
|
314 |
+
|
315 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
316 |
+
|
317 |
+
# resize and padding image
|
318 |
+
nw, nh = best_resolution
|
319 |
+
ow, oh = image.size
|
320 |
+
|
321 |
+
scale_factor = min(nw / ow, nh / oh)
|
322 |
+
new_size = (int(ow * scale_factor), int(oh * scale_factor))
|
323 |
+
|
324 |
+
image_padded = Image.new("RGB", (nw, nh), padding_value)
|
325 |
+
image_padded.paste(image.resize(new_size), ((nw - new_size[0]) // 2, (nh - new_size[1]) // 2))
|
326 |
+
|
327 |
+
image_grid = grid_divide(image_padded, image_size)
|
328 |
+
|
329 |
+
if use_thumbnail:
|
330 |
+
thumbnail_img = image.resize((image_size, image_size))
|
331 |
+
image_grid = [[thumbnail_img]] + image_grid
|
332 |
+
|
333 |
+
return image_grid
|
334 |
+
|
335 |
+
|
336 |
+
def process_adares_image(image_path, image_size=384, use_thumbnail=True):
|
337 |
+
# Grid Params:
|
338 |
+
min_num = 1
|
339 |
+
max_num = 12
|
340 |
+
|
341 |
+
if isinstance(image_size, int):
|
342 |
+
image_size = (image_size, image_size)
|
343 |
+
|
344 |
+
ori_size = image.size
|
345 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
346 |
+
|
347 |
+
# calculate the existing image aspect ratio
|
348 |
+
tgt_ratios = []
|
349 |
+
for n in range(min_num, max_num + 1):
|
350 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
351 |
+
tgt_ratios = set(tgt_ratios)
|
352 |
+
possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]
|
353 |
+
|
354 |
+
# find the most possible resolution
|
355 |
+
best_resolution = select_best_resolution(ori_size, possible_resolutions)
|
356 |
+
|
357 |
+
# resize the image to the target size
|
358 |
+
resized_img = image.resize((best_resolution[0], best_resolution[1]))
|
359 |
+
|
360 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
361 |
+
|
362 |
+
if use_thumbnail:
|
363 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
364 |
+
image_grid = [[thumbnail_img]] + image_grid
|
365 |
+
|
366 |
+
return image_grid
|
367 |
+
|
368 |
+
|
369 |
+
def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
|
370 |
+
images = load_images(image_path)
|
371 |
+
|
372 |
+
padding_value = tuple(int(x*255) for x in processor.image_mean)
|
373 |
+
|
374 |
+
image_grids = []
|
375 |
+
for image in images:
|
376 |
+
if aspect_ratio == 'pad':
|
377 |
+
image_grid = process_pad_image(image, padding_value=padding_value)
|
378 |
+
elif aspect_ratio == 'dynamic':
|
379 |
+
image_grid = process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
380 |
+
elif aspect_ratio == 'highres':
|
381 |
+
image_grid = process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
382 |
+
elif aspect_ratio == 'anyres':
|
383 |
+
image_grid = process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
384 |
+
elif aspect_ratio == 'adares':
|
385 |
+
image_grid = process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
386 |
+
else:
|
387 |
+
image_grid = [image]
|
388 |
+
|
389 |
+
image_grid = [processor.preprocess(image_row, return_tensors='pt', num_images=len(images)) for image_row in image_grid]
|
390 |
+
image_grids.append(image_grid)
|
391 |
+
|
392 |
+
return image_grids
|
393 |
+
|
394 |
+
|
395 |
+
def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
|
396 |
+
if mode == 'uniform':
|
397 |
+
assert num_frames is not None, "Number of frames must be provided for uniform sampling."
|
398 |
+
if duration <= num_frames:
|
399 |
+
return np.arange(duration).astype(int)
|
400 |
+
# NOTE: v1 version
|
401 |
+
# Calculate the size of each segment from which a frame will be extracted
|
402 |
+
# if duration <= num_frames:
|
403 |
+
# return np.arange(duration).astype(int)
|
404 |
+
# seg_size = float(duration - 1) / num_frames
|
405 |
+
|
406 |
+
# frame_ids = []
|
407 |
+
# for i in range(num_frames):
|
408 |
+
# # Calculate the start and end indices of each segment
|
409 |
+
# start = seg_size * i
|
410 |
+
# end = seg_size * (i + 1)
|
411 |
+
# # Append the middle index of the segment to the list
|
412 |
+
# frame_ids.append((start + end) / 2)
|
413 |
+
|
414 |
+
# return np.round(np.array(frame_ids) + 1e-6).astype(int)
|
415 |
+
# NOTE: v0 version
|
416 |
+
return np.linspace(0, duration-1, num_frames, dtype=int)
|
417 |
+
elif mode == 'fps':
|
418 |
+
assert vid_fps is not None, "FPS must be provided for FPS sampling."
|
419 |
+
fps = fps if fps is not None else NUM_FRAMES_PER_SECOND
|
420 |
+
segment_len = min(vid_fps // fps, duration)
|
421 |
+
return np.arange(segment_len // 2, duration, segment_len, dtype=int)
|
422 |
+
else:
|
423 |
+
raise ImportError(f'Unsupported frame sampling mode: {mode}')
|
424 |
+
|
425 |
+
|
426 |
+
def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, frame_ids=None):
|
427 |
+
if s is not None and e is not None:
|
428 |
+
s = s if s >= 0. else 0.
|
429 |
+
e = e if e >= 0. else 0.
|
430 |
+
if s > e:
|
431 |
+
s, e = e, s
|
432 |
+
elif s == e:
|
433 |
+
e = s + 1
|
434 |
+
|
435 |
+
# 1. Loading Video
|
436 |
+
if os.path.isdir(video_path):
|
437 |
+
frame_files = sorted(os.listdir(video_path))
|
438 |
+
|
439 |
+
vid_fps = 3
|
440 |
+
num_frames_of_video = len(frame_files)
|
441 |
+
elif video_path.endswith('.gif'):
|
442 |
+
gif_reader = imageio.get_reader(video_path)
|
443 |
+
|
444 |
+
vid_fps = 25
|
445 |
+
num_frames_of_video = len(gif_reader)
|
446 |
+
else:
|
447 |
+
vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
|
448 |
+
# vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
449 |
+
|
450 |
+
vid_fps = vreader.get_avg_fps()
|
451 |
+
num_frames_of_video = len(vreader)
|
452 |
+
|
453 |
+
# 2. Determine frame range & Calculate frame indices
|
454 |
+
f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
|
455 |
+
f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
|
456 |
+
frame_indices = list(range(f_start, f_end + 1))
|
457 |
+
|
458 |
+
duration = len(frame_indices)
|
459 |
+
# 3. Sampling frame indices
|
460 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
461 |
+
if fps is not None and duration / vid_fps < max_frames:
|
462 |
+
try:
|
463 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
|
464 |
+
except:
|
465 |
+
print('sampled_frame_indices error: ', )
|
466 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
|
467 |
+
|
468 |
+
else:
|
469 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
|
470 |
+
|
471 |
+
# 4. Acquire frame data
|
472 |
+
if os.path.isdir(video_path):
|
473 |
+
frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
|
474 |
+
elif video_path.endswith('.gif'):
|
475 |
+
frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
|
476 |
+
else:
|
477 |
+
frames = vreader.get_batch(sampled_frame_indices).asnumpy()
|
478 |
+
|
479 |
+
# frames = frames.transpose(0, 3, 1, 2)
|
480 |
+
timestamps = [x / vid_fps for x in sampled_frame_indices]
|
481 |
+
|
482 |
+
if temporal_factor > 1:
|
483 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
484 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
485 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
486 |
+
|
487 |
+
# NOTE: pad the video with black frames
|
488 |
+
# while num_frames is not None and len(video_data) < num_frames:
|
489 |
+
# video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
|
490 |
+
|
491 |
+
additional_frames = []
|
492 |
+
if frame_ids is not None:
|
493 |
+
if os.path.isdir(video_path):
|
494 |
+
additional_frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in frame_ids]
|
495 |
+
elif video_path.endswith('.gif'):
|
496 |
+
additional_frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in frame_ids]
|
497 |
+
else:
|
498 |
+
additional_frames = vreader.get_batch(frame_ids).asnumpy()
|
499 |
+
|
500 |
+
return frames, timestamps, additional_frames
|
501 |
+
|
502 |
+
|
503 |
+
def load_video(
|
504 |
+
video_path: str,
|
505 |
+
start_time: Optional[float] = None,
|
506 |
+
end_time: Optional[float] = None,
|
507 |
+
fps: Optional[float] = None,
|
508 |
+
max_frames: Optional[float] = None,
|
509 |
+
size: Optional[int] = None,
|
510 |
+
size_divisible: int = 1,
|
511 |
+
precise_time: bool = False,
|
512 |
+
verbose: bool = False,
|
513 |
+
temporal_factor: int = 1,
|
514 |
+
frame_ids = None
|
515 |
+
):
|
516 |
+
"""
|
517 |
+
Load and process a video file and return the frames and the timestamps of each frame.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
video_path (str): Path to the video file.
|
521 |
+
start_time (float, optional): Start time in seconds. Defaults to None.
|
522 |
+
end_time (float, optional): End time in seconds. Defaults to None.
|
523 |
+
fps (float, optional): Frames per second. Defaults to None.
|
524 |
+
num_frames (float, optional): Number of frames to sample. Defaults to None.
|
525 |
+
size (int, optional): Size of the shortest side. Defaults to None.
|
526 |
+
size_divisible (int, optional): Size divisible by this number. Defaults to 1.
|
527 |
+
precise_time (bool, optional): Whether to use precise time. Defaults to False.
|
528 |
+
verbose (bool, optional): Print ffmpeg output. Defaults to False.
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
frames (List[PIL.Image]): List of frames.
|
532 |
+
timestamps (List[float]): List of timestamps.
|
533 |
+
"""
|
534 |
+
if start_time is not None and end_time is not None and end_time - start_time < 1:
|
535 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
536 |
+
if os.path.isdir(video_path):
|
537 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
538 |
+
if video_path.endswith('.gif'):
|
539 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
540 |
+
probe = ffmpeg.probe(video_path)
|
541 |
+
duration = float(probe['format']['duration'])
|
542 |
+
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
|
543 |
+
w, h = int(video_stream['width']), int(video_stream['height'])
|
544 |
+
|
545 |
+
kwargs, input_kwargs, output_kwargs = {}, {}, {}
|
546 |
+
do_trim = start_time is not None or end_time is not None
|
547 |
+
if start_time is not None:
|
548 |
+
new_start_time = max(float(video_stream['start_time']), start_time)
|
549 |
+
duration -= new_start_time - start_time
|
550 |
+
start_time = new_start_time
|
551 |
+
else:
|
552 |
+
start_time = float(video_stream['start_time'])
|
553 |
+
if end_time is not None:
|
554 |
+
duration = min(duration, end_time - start_time)
|
555 |
+
else:
|
556 |
+
duration = duration
|
557 |
+
if do_trim:
|
558 |
+
kwargs = {'ss': start_time, 't': duration}
|
559 |
+
if precise_time:
|
560 |
+
output_kwargs.update(kwargs)
|
561 |
+
else:
|
562 |
+
input_kwargs.update(kwargs)
|
563 |
+
|
564 |
+
if size is not None:
|
565 |
+
scale_factor = size / min(w, h)
|
566 |
+
new_w, new_h = round(w * scale_factor), round(h * scale_factor)
|
567 |
+
else:
|
568 |
+
new_w, new_h = w, h
|
569 |
+
new_w = new_w // size_divisible * size_divisible
|
570 |
+
new_h = new_h // size_divisible * size_divisible
|
571 |
+
|
572 |
+
# NOTE: It may result in unexpected number of frames in ffmpeg
|
573 |
+
# if calculate the fps directly according to max_frames
|
574 |
+
# NOTE: the below lines may hurt the performance
|
575 |
+
# if max_frames is not None and (fps is None or duration * fps > 2 * max_frames):
|
576 |
+
# fps = max_frames / duration * 2
|
577 |
+
|
578 |
+
stream = ffmpeg.input(video_path, **input_kwargs)
|
579 |
+
if fps is not None:
|
580 |
+
stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
|
581 |
+
if new_w != w or new_h != h:
|
582 |
+
stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
|
583 |
+
stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
|
584 |
+
out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
|
585 |
+
|
586 |
+
frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
|
587 |
+
|
588 |
+
if fps is not None:
|
589 |
+
timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
|
590 |
+
else:
|
591 |
+
timestamps = np.linspace(start_time, start_time + duration, len(frames))
|
592 |
+
|
593 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
594 |
+
if max_frames is not None and len(frames) > max_frames:
|
595 |
+
indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
|
596 |
+
frames = frames[indices]
|
597 |
+
timestamps = [timestamps[i] for i in indices]
|
598 |
+
|
599 |
+
if temporal_factor > 1:
|
600 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
601 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
602 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
603 |
+
|
604 |
+
frames = [frame for frame in frames]
|
605 |
+
additional_frames = []
|
606 |
+
# print('frame_ids', frame_ids)
|
607 |
+
if frame_ids is not None:
|
608 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
609 |
+
additional_frames = vr.get_batch(frame_ids).asnumpy()
|
610 |
+
|
611 |
+
return frames, timestamps, additional_frames
|
612 |
+
|
613 |
+
|
614 |
+
def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
|
615 |
+
fps = 1 if num_frames is None else None
|
616 |
+
# FFmpeg
|
617 |
+
frames, timestamps = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
|
618 |
+
# Decord
|
619 |
+
# frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)
|
620 |
+
|
621 |
+
assert len(frames) == len(timestamps), "Number of frames and timestamps must match."
|
622 |
+
|
623 |
+
if aspect_ratio == 'pad':
|
624 |
+
frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]
|
625 |
+
|
626 |
+
if aspect_ratio == 'qwen2vl':
|
627 |
+
frames = [processor.preprocess(frame, return_tensors='pt', image_num=len(frames)) for frame in frames]
|
628 |
+
grid_frames = [frames]
|
629 |
+
else:
|
630 |
+
frames = processor.preprocess(frames, return_tensors='pt', image_num=len(frames))
|
631 |
+
grid_frames = [[frames]]
|
632 |
+
|
633 |
+
return grid_frames, timestamps
|
634 |
+
|
635 |
+
|
636 |
+
def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
|
637 |
+
"""Tokenize text and multimodal tag to input_ids.
|
638 |
+
|
639 |
+
Args:
|
640 |
+
prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
|
641 |
+
tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
|
642 |
+
multimodal_token (int): Token index corresponding to the multimodal tag.
|
643 |
+
"""
|
644 |
+
multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
|
645 |
+
if multimodal_token_index is None:
|
646 |
+
input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
|
647 |
+
else:
|
648 |
+
prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for idx, chunk in enumerate(prompt.split(multimodal_token))]
|
649 |
+
|
650 |
+
input_ids = []
|
651 |
+
for i in range(1, 2 * len(prompt_chunks)):
|
652 |
+
if i % 2 == 1:
|
653 |
+
input_ids.extend(prompt_chunks[i // 2])
|
654 |
+
else:
|
655 |
+
input_ids.append(multimodal_token_index)
|
656 |
+
|
657 |
+
if return_tensors is not None:
|
658 |
+
if return_tensors == 'pt':
|
659 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
660 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
661 |
+
return input_ids
|
662 |
+
|
663 |
+
|
664 |
+
def get_model_name_from_path(model_path):
|
665 |
+
model_path = model_path.strip("/")
|
666 |
+
model_paths = model_path.split("/")
|
667 |
+
if model_paths[-1].startswith('checkpoint-'):
|
668 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
669 |
+
else:
|
670 |
+
return model_paths[-1]
|
671 |
+
|
672 |
+
|
673 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
674 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
675 |
+
self.keywords = keywords
|
676 |
+
self.keyword_ids = []
|
677 |
+
self.max_keyword_len = 0
|
678 |
+
for keyword in keywords:
|
679 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
680 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
681 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
682 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
683 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
684 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
685 |
+
self.tokenizer = tokenizer
|
686 |
+
self.start_len = input_ids.shape[1]
|
687 |
+
|
688 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
689 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
690 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
691 |
+
for keyword_id in self.keyword_ids:
|
692 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
693 |
+
return True
|
694 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
695 |
+
for keyword in self.keywords:
|
696 |
+
if keyword in outputs:
|
697 |
+
return True
|
698 |
+
return False
|
699 |
+
|
700 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
701 |
+
outputs = []
|
702 |
+
for i in range(output_ids.shape[0]):
|
703 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
704 |
+
return all(outputs)
|
videollama3/model/__init__.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
2 |
+
# Copyright 2023 Haotian Liu
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import warnings
|
19 |
+
import shutil
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
23 |
+
|
24 |
+
from .projector import load_mm_projector
|
25 |
+
from .videollama3_qwen2 import Videollama3Qwen2ForCausalLM, Videollama3Qwen2Config
|
26 |
+
|
27 |
+
|
28 |
+
VLLMs = {
|
29 |
+
"videollama3_qwen2": Videollama3Qwen2ForCausalLM,
|
30 |
+
}
|
31 |
+
|
32 |
+
VLLMConfigs = {
|
33 |
+
"videollama3_qwen2": Videollama3Qwen2Config,
|
34 |
+
}
|
35 |
+
|
36 |
+
|
37 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", **kwargs):
|
38 |
+
if 'token' in kwargs:
|
39 |
+
token = kwargs['token']
|
40 |
+
else:
|
41 |
+
token = None
|
42 |
+
|
43 |
+
# NOTE: auto device_map by default
|
44 |
+
# if want to put model into a single device, you can set device_map={"": "cuda:0"}
|
45 |
+
kwargs = {"device_map": device_map, **kwargs}
|
46 |
+
|
47 |
+
config = AutoConfig.from_pretrained(model_path)
|
48 |
+
config._attn_implementation = kwargs.pop('attn_implementation', "flash_attention_2") # default to flash_attention_2
|
49 |
+
|
50 |
+
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else kwargs.pop('torch_dtype', torch.float16)
|
51 |
+
|
52 |
+
if load_8bit:
|
53 |
+
kwargs['load_in_8bit'] = True
|
54 |
+
elif load_4bit:
|
55 |
+
# NOTE: High-version Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
|
56 |
+
# kwargs['load_in_4bit'] = True
|
57 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
58 |
+
load_in_4bit=True,
|
59 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
60 |
+
bnb_4bit_use_double_quant=True,
|
61 |
+
bnb_4bit_quant_type='nf4'
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
kwargs['torch_dtype'] = torch_dtype
|
65 |
+
|
66 |
+
# judge model type
|
67 |
+
model_type = config.model_type if hasattr(config, "model_type") else kwargs.pop('model_type', "videollama3_qwen2")
|
68 |
+
|
69 |
+
# judge pretrain/finetune
|
70 |
+
is_alignment = getattr(config, "tune_mm_mlp_adapter", False) or getattr(config, "is_alignment", False)
|
71 |
+
|
72 |
+
# NOTE: lora/qlora model loading
|
73 |
+
if 'lora' in model_name.lower() or 'qlora' in model_name.lower():
|
74 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
75 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
76 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
77 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
78 |
+
|
79 |
+
# NOTE: remove qlora training quantization config
|
80 |
+
if hasattr(lora_cfg_pretrained, 'quantization_config'):
|
81 |
+
del lora_cfg_pretrained.quantization_config
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
|
83 |
+
print('Loading VideoLLaMA from base model...')
|
84 |
+
|
85 |
+
if 'qwen2' in model_base.lower():
|
86 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
87 |
+
else:
|
88 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
89 |
+
|
90 |
+
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
|
91 |
+
if model.lm_head.weight.shape[0] != token_num:
|
92 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
93 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
|
94 |
+
|
95 |
+
print('Loading additional VideoLLaMA weights...')
|
96 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
97 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
98 |
+
else:
|
99 |
+
# this is probably from HF Hub
|
100 |
+
from huggingface_hub import hf_hub_download
|
101 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
102 |
+
cache_file = hf_hub_download(
|
103 |
+
repo_id=repo_id,
|
104 |
+
filename=filename,
|
105 |
+
subfolder=subfolder)
|
106 |
+
return torch.load(cache_file, map_location='cpu')
|
107 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
108 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
109 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
110 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
111 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
112 |
+
|
113 |
+
from peft import PeftModel
|
114 |
+
print('Loading LoRA weights...')
|
115 |
+
model = PeftModel.from_pretrained(model, model_path)
|
116 |
+
print('Merging LoRA weights...')
|
117 |
+
model = model.merge_and_unload()
|
118 |
+
print('Model is loaded...')
|
119 |
+
elif model_base is not None or '-base' in model_name.lower() or is_alignment:
|
120 |
+
# NOTE: Base/Pretrain model loading
|
121 |
+
print('Loading VideoLLaMA 2 from base model...')
|
122 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
123 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
124 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
125 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
126 |
+
|
127 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
|
128 |
+
|
129 |
+
if model_type in ['videollama3', 'videollama3_qwen2']:
|
130 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
131 |
+
else:
|
132 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
133 |
+
|
134 |
+
# NOTE; loading vision-language projector
|
135 |
+
# * old codes for loading local mm_projector.bin
|
136 |
+
# mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
137 |
+
# mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
138 |
+
# model.load_state_dict(mm_projector_weights, strict=False)
|
139 |
+
# * new codes which supports loading mm_projector.bin both offline and online
|
140 |
+
mm_projector_weights = load_mm_projector(model_path, token=token)
|
141 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
142 |
+
elif 'videollama' in model_type:
|
143 |
+
# NOTE: SFT model loading
|
144 |
+
print(model_path)
|
145 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
|
146 |
+
|
147 |
+
if model_type in ['videollama3_qwen2']:
|
148 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
149 |
+
else:
|
150 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
151 |
+
else:
|
152 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token)
|
153 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
|
154 |
+
|
155 |
+
processor = None
|
156 |
+
|
157 |
+
if "videollama" in model_type:
|
158 |
+
vision_encoder = model.get_vision_encoder()
|
159 |
+
processor = vision_encoder.image_processor
|
160 |
+
|
161 |
+
if hasattr(model.config, "max_sequence_length"):
|
162 |
+
context_len = model.config.max_sequence_length
|
163 |
+
else:
|
164 |
+
context_len = 2048
|
165 |
+
|
166 |
+
return tokenizer, model, processor, context_len
|
videollama3/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (4.04 kB). View file
|
|
videollama3/model/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (10.5 kB). View file
|
|
videollama3/model/__pycache__/processor.cpython-310.pyc
ADDED
Binary file (12.2 kB). View file
|
|
videollama3/model/__pycache__/projector.cpython-310.pyc
ADDED
Binary file (5.11 kB). View file
|
|
videollama3/model/__pycache__/region_encoder.cpython-310.pyc
ADDED
Binary file (3.43 kB). View file
|
|
videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc
ADDED
Binary file (9.74 kB). View file
|
|
videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc
ADDED
Binary file (4.2 kB). View file
|
|
videollama3/model/damovl_encoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .configuration_damovl_encoder import DAMOVLVisionConfig
|
2 |
+
from .image_processing import DAMOVLImageProcessor
|
3 |
+
from .modeling_damovl_encoder import DAMOVLVisionModel
|
videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (407 Bytes). View file
|
|
videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc
ADDED
Binary file (1.96 kB). View file
|
|
videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc
ADDED
Binary file (16.7 kB). View file
|
|
videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc
ADDED
Binary file (16.9 kB). View file
|
|
videollama3/model/damovl_encoder/configuration_damovl_encoder.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Qwen2VL model configuration"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import Union
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class DAMOVLVisionConfig(PretrainedConfig):
|
28 |
+
model_type = "damovl"
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
hidden_size=768,
|
33 |
+
intermediate_size=3072,
|
34 |
+
num_hidden_layers=12,
|
35 |
+
num_attention_heads=12,
|
36 |
+
num_channels=3,
|
37 |
+
patch_size=16,
|
38 |
+
hidden_act="gelu_pytorch_tanh",
|
39 |
+
layer_norm_eps=1e-6,
|
40 |
+
attention_dropout=0.0,
|
41 |
+
spatial_merge_size=1,
|
42 |
+
**kwargs,
|
43 |
+
):
|
44 |
+
super().__init__(**kwargs)
|
45 |
+
|
46 |
+
self.hidden_size = hidden_size
|
47 |
+
self.intermediate_size = intermediate_size
|
48 |
+
self.num_hidden_layers = num_hidden_layers
|
49 |
+
self.num_attention_heads = num_attention_heads
|
50 |
+
self.num_channels = num_channels
|
51 |
+
self.patch_size = patch_size
|
52 |
+
self.attention_dropout = attention_dropout
|
53 |
+
self.layer_norm_eps = layer_norm_eps
|
54 |
+
self.hidden_act = hidden_act
|
55 |
+
self.spatial_merge_size = spatial_merge_size
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
59 |
+
cls._set_token_in_kwargs(kwargs)
|
60 |
+
|
61 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
62 |
+
|
63 |
+
# config_dict = config_dict["vision_config"]
|
64 |
+
|
65 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
66 |
+
logger.warning(
|
67 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
68 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
69 |
+
)
|
70 |
+
|
71 |
+
return cls.from_dict(config_dict, **kwargs)
|
videollama3/model/damovl_encoder/image_processing.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""Image processor class for Qwen2-VL."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from typing import Dict, List, Optional, Union
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
28 |
+
from transformers.image_transforms import (
|
29 |
+
convert_to_rgb,
|
30 |
+
resize,
|
31 |
+
to_channel_dimension_format,
|
32 |
+
)
|
33 |
+
from transformers.image_utils import (
|
34 |
+
OPENAI_CLIP_MEAN,
|
35 |
+
OPENAI_CLIP_STD,
|
36 |
+
ChannelDimension,
|
37 |
+
ImageInput,
|
38 |
+
PILImageResampling,
|
39 |
+
VideoInput,
|
40 |
+
get_image_size,
|
41 |
+
infer_channel_dimension_format,
|
42 |
+
is_scaled_image,
|
43 |
+
is_valid_image,
|
44 |
+
make_list_of_images,
|
45 |
+
to_numpy_array,
|
46 |
+
valid_images,
|
47 |
+
validate_preprocess_arguments,
|
48 |
+
)
|
49 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
50 |
+
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
if is_vision_available():
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
|
59 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
60 |
+
"""
|
61 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
65 |
+
The input image.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
list: A list of images.
|
69 |
+
"""
|
70 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
71 |
+
return [img for img_list in images for img in img_list]
|
72 |
+
|
73 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
74 |
+
return images
|
75 |
+
|
76 |
+
elif is_valid_image(images):
|
77 |
+
return [images]
|
78 |
+
|
79 |
+
raise ValueError(f"Could not make batched images from {images}")
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
83 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
84 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
85 |
+
return videos
|
86 |
+
|
87 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
88 |
+
if isinstance(videos[0], Image.Image):
|
89 |
+
return [videos]
|
90 |
+
elif len(videos[0].shape) == 4:
|
91 |
+
return [list(video) for video in videos]
|
92 |
+
|
93 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
94 |
+
return [list(videos)]
|
95 |
+
|
96 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
97 |
+
|
98 |
+
|
99 |
+
def smart_resize(
|
100 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
101 |
+
):
|
102 |
+
"""Rescales the image so that the following conditions are met:
|
103 |
+
|
104 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
105 |
+
|
106 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
107 |
+
|
108 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
109 |
+
|
110 |
+
"""
|
111 |
+
if height < factor or width < factor:
|
112 |
+
scale = factor / min(height, width)
|
113 |
+
width = round(scale * width)
|
114 |
+
height = round(scale * height)
|
115 |
+
elif max(height, width) / min(height, width) > 200:
|
116 |
+
raise ValueError(
|
117 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
118 |
+
)
|
119 |
+
h_bar = round(height / factor) * factor
|
120 |
+
w_bar = round(width / factor) * factor
|
121 |
+
if h_bar * w_bar > max_pixels:
|
122 |
+
beta = math.sqrt((height * width) / max_pixels)
|
123 |
+
h_bar = math.floor(height / beta / factor) * factor
|
124 |
+
w_bar = math.floor(width / beta / factor) * factor
|
125 |
+
elif h_bar * w_bar < min_pixels:
|
126 |
+
beta = math.sqrt(min_pixels / (height * width))
|
127 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
128 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
129 |
+
return h_bar, w_bar
|
130 |
+
|
131 |
+
|
132 |
+
class DAMOVLImageProcessor(BaseImageProcessor):
|
133 |
+
r"""
|
134 |
+
Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
138 |
+
Whether to resize the image's (height, width) dimensions.
|
139 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
140 |
+
Resampling filter to use when resizing the image.
|
141 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
143 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
144 |
+
Scale factor to use if rescaling the image.
|
145 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
146 |
+
Whether to normalize the image.
|
147 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
148 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
149 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
150 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
151 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
152 |
+
Whether to convert the image to RGB.
|
153 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
154 |
+
The min pixels of the image to resize the image.
|
155 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
156 |
+
The max pixels of the image to resize the image.
|
157 |
+
patch_size (`int`, *optional*, defaults to 14):
|
158 |
+
The spacial patch size of the vision encoder.
|
159 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
160 |
+
The temporal patch size of the vision encoder.
|
161 |
+
merge_size (`int`, *optional*, defaults to 2):
|
162 |
+
The merge size of the vision encoder to llm encoder.
|
163 |
+
"""
|
164 |
+
|
165 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
do_resize: bool = True,
|
170 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
171 |
+
do_rescale: bool = True,
|
172 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
173 |
+
do_normalize: bool = True,
|
174 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
175 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
176 |
+
do_convert_rgb: bool = True,
|
177 |
+
min_pixels: int = 56 * 56,
|
178 |
+
max_pixels: int = 14 * 14 * 9477,
|
179 |
+
patch_size: int = 14,
|
180 |
+
merge_size: int = 1,
|
181 |
+
**kwargs,
|
182 |
+
) -> None:
|
183 |
+
super().__init__(**kwargs)
|
184 |
+
self.do_resize = do_resize
|
185 |
+
self.resample = resample
|
186 |
+
self.do_rescale = do_rescale
|
187 |
+
self.rescale_factor = rescale_factor
|
188 |
+
self.do_normalize = do_normalize
|
189 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
190 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
191 |
+
self.min_pixels = min_pixels
|
192 |
+
self.max_pixels = max_pixels
|
193 |
+
self.patch_size = patch_size
|
194 |
+
self.merge_size = merge_size
|
195 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
196 |
+
self.do_convert_rgb = do_convert_rgb
|
197 |
+
|
198 |
+
self.temporal_patch_size = 1
|
199 |
+
|
200 |
+
def _preprocess(
|
201 |
+
self,
|
202 |
+
images: Union[ImageInput, VideoInput],
|
203 |
+
do_resize: bool = None,
|
204 |
+
resample: PILImageResampling = None,
|
205 |
+
do_rescale: bool = None,
|
206 |
+
rescale_factor: float = None,
|
207 |
+
do_normalize: bool = None,
|
208 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
209 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
210 |
+
do_convert_rgb: bool = None,
|
211 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
212 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
213 |
+
num_images: Optional[int] = 1,
|
214 |
+
image_downsampling: Optional[int] = None,
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
images (`ImageInput`):
|
221 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
222 |
+
vision_info (`List[Dict]`, *optional*):
|
223 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
224 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
225 |
+
Whether to resize the image.
|
226 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
227 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
228 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
229 |
+
Whether to rescale the image.
|
230 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
231 |
+
Scale factor to use if rescaling the image.
|
232 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
233 |
+
Whether to normalize the image.
|
234 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
235 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
236 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
237 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
238 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
239 |
+
Whether to convert the image to RGB.
|
240 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
241 |
+
The channel dimension format for the output image. Can be one of:
|
242 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
243 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
244 |
+
- Unset: Use the channel dimension format of the input image.
|
245 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
246 |
+
The channel dimension format for the input image. Can be one of:
|
247 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
248 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
249 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
250 |
+
"""
|
251 |
+
images = make_list_of_images(images)
|
252 |
+
|
253 |
+
if do_convert_rgb:
|
254 |
+
images = [convert_to_rgb(image) for image in images]
|
255 |
+
|
256 |
+
# All transformations expect numpy arrays.
|
257 |
+
images = [to_numpy_array(image) for image in images]
|
258 |
+
|
259 |
+
if is_scaled_image(images[0]) and do_rescale:
|
260 |
+
logger.warning_once(
|
261 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
262 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
263 |
+
)
|
264 |
+
if input_data_format is None:
|
265 |
+
# We assume that all images have the same channel dimension format.
|
266 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
267 |
+
|
268 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
269 |
+
resized_height, resized_width = height, width
|
270 |
+
processed_images = []
|
271 |
+
for image in images:
|
272 |
+
if do_resize:
|
273 |
+
max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
|
274 |
+
resized_height, resized_width = smart_resize(
|
275 |
+
height,
|
276 |
+
width,
|
277 |
+
factor=self.patch_size * image_downsampling,
|
278 |
+
min_pixels=self.min_pixels,
|
279 |
+
max_pixels=int(max_pixels // num_images),
|
280 |
+
)
|
281 |
+
image = resize(
|
282 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
283 |
+
)
|
284 |
+
|
285 |
+
if do_rescale:
|
286 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
287 |
+
|
288 |
+
if do_normalize:
|
289 |
+
image = self.normalize(
|
290 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
291 |
+
)
|
292 |
+
|
293 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
294 |
+
processed_images.append(image)
|
295 |
+
|
296 |
+
patches = np.array(processed_images)
|
297 |
+
if data_format == ChannelDimension.LAST:
|
298 |
+
patches = patches.transpose(0, 3, 1, 2)
|
299 |
+
|
300 |
+
channel = patches.shape[1]
|
301 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
302 |
+
patches = patches.reshape(
|
303 |
+
channel,
|
304 |
+
grid_h // image_downsampling,
|
305 |
+
image_downsampling,
|
306 |
+
self.patch_size,
|
307 |
+
grid_w // image_downsampling,
|
308 |
+
image_downsampling,
|
309 |
+
self.patch_size,
|
310 |
+
)
|
311 |
+
patches = patches.transpose(1, 4, 2, 5, 0, 3, 6)
|
312 |
+
flatten_patches = patches.reshape(
|
313 |
+
grid_h * grid_w, channel * self.patch_size * self.patch_size
|
314 |
+
)
|
315 |
+
# print('image_downsampling', image_downsampling)
|
316 |
+
# flatten_patches1 = flatten_patches.reshape(grid_h, grid_w, channel, -1)
|
317 |
+
# from matplotlib import pyplot as plt
|
318 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
319 |
+
# plt.savefig('8.png')
|
320 |
+
|
321 |
+
return flatten_patches, (1, grid_h, grid_w)
|
322 |
+
|
323 |
+
def preprocess(
|
324 |
+
self,
|
325 |
+
images: ImageInput,
|
326 |
+
videos: VideoInput = None,
|
327 |
+
do_resize: bool = None,
|
328 |
+
size: Dict[str, int] = None,
|
329 |
+
resample: PILImageResampling = None,
|
330 |
+
do_rescale: bool = None,
|
331 |
+
rescale_factor: float = None,
|
332 |
+
do_normalize: bool = None,
|
333 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
334 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
335 |
+
do_convert_rgb: bool = None,
|
336 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
337 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
338 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
339 |
+
num_images: Optional[int] = 1,
|
340 |
+
image_downsampling: Optional[int] = None,
|
341 |
+
):
|
342 |
+
"""
|
343 |
+
Args:
|
344 |
+
images (`ImageInput`):
|
345 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
346 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
347 |
+
videos (`VideoInput`):
|
348 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
349 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
350 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
351 |
+
Whether to resize the image.
|
352 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
353 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
354 |
+
the longest edge resized to keep the input aspect ratio.
|
355 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
356 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
357 |
+
has an effect if `do_resize` is set to `True`.
|
358 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
359 |
+
Whether to rescale the image.
|
360 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
361 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
362 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
363 |
+
Whether to normalize the image.
|
364 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
365 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
366 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
367 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
368 |
+
`True`.
|
369 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
370 |
+
Whether to convert the image to RGB.
|
371 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
372 |
+
The type of tensors to return. Can be one of:
|
373 |
+
- Unset: Return a list of `np.ndarray`.
|
374 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
375 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
376 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
377 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
378 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
379 |
+
The channel dimension format for the output image. Can be one of:
|
380 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
381 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
382 |
+
- Unset: Use the channel dimension format of the input image.
|
383 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
384 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
385 |
+
from the input image. Can be one of:
|
386 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
387 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
388 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
389 |
+
|
390 |
+
"""
|
391 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
392 |
+
size = size if size is not None else self.size
|
393 |
+
resample = resample if resample is not None else self.resample
|
394 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
395 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
396 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
397 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
398 |
+
image_std = image_std if image_std is not None else self.image_std
|
399 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
400 |
+
image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size
|
401 |
+
|
402 |
+
if images is not None:
|
403 |
+
images = make_batched_images(images)
|
404 |
+
if videos is not None:
|
405 |
+
videos = make_batched_videos(videos)
|
406 |
+
|
407 |
+
if images is not None and not valid_images(images):
|
408 |
+
raise ValueError(
|
409 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
410 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
411 |
+
)
|
412 |
+
|
413 |
+
validate_preprocess_arguments(
|
414 |
+
rescale_factor=rescale_factor,
|
415 |
+
do_normalize=do_normalize,
|
416 |
+
image_mean=image_mean,
|
417 |
+
image_std=image_std,
|
418 |
+
do_resize=do_resize,
|
419 |
+
size=size,
|
420 |
+
resample=resample,
|
421 |
+
)
|
422 |
+
|
423 |
+
if images is not None:
|
424 |
+
pixel_values, vision_grid_thws = [], []
|
425 |
+
for image in images:
|
426 |
+
patches, image_grid_thw = self._preprocess(
|
427 |
+
image,
|
428 |
+
do_resize=do_resize,
|
429 |
+
resample=resample,
|
430 |
+
do_rescale=do_rescale,
|
431 |
+
rescale_factor=rescale_factor,
|
432 |
+
do_normalize=do_normalize,
|
433 |
+
image_mean=image_mean,
|
434 |
+
image_std=image_std,
|
435 |
+
data_format=data_format,
|
436 |
+
do_convert_rgb=do_convert_rgb,
|
437 |
+
input_data_format=input_data_format,
|
438 |
+
num_images=num_images,
|
439 |
+
image_downsampling=image_downsampling,
|
440 |
+
)
|
441 |
+
pixel_values.extend(patches)
|
442 |
+
vision_grid_thws.append(image_grid_thw)
|
443 |
+
pixel_values = np.array(pixel_values)
|
444 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
445 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
446 |
+
|
447 |
+
assert videos is None, "Not support video for now."
|
448 |
+
# NOTE: not support video for now
|
449 |
+
# if videos is not None:
|
450 |
+
# pixel_values, vision_grid_thws = [], []
|
451 |
+
# for images in videos:
|
452 |
+
# patches, video_grid_thw = self._preprocess(
|
453 |
+
# images,
|
454 |
+
# do_resize=do_resize,
|
455 |
+
# resample=resample,
|
456 |
+
# do_rescale=do_rescale,
|
457 |
+
# rescale_factor=rescale_factor,
|
458 |
+
# do_normalize=do_normalize,
|
459 |
+
# image_mean=image_mean,
|
460 |
+
# image_std=image_std,
|
461 |
+
# data_format=data_format,
|
462 |
+
# do_convert_rgb=do_convert_rgb,
|
463 |
+
# input_data_format=input_data_format,
|
464 |
+
# image_num=image_num,
|
465 |
+
# )
|
466 |
+
# pixel_values.extend(patches)
|
467 |
+
# vision_grid_thws.append(video_grid_thw)
|
468 |
+
# pixel_values = np.array(pixel_values)
|
469 |
+
# vision_grid_thws = np.array(vision_grid_thws)
|
470 |
+
# data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
471 |
+
|
472 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
videollama3/model/damovl_encoder/modeling_damovl_encoder.py
ADDED
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch Siglip model."""
|
16 |
+
|
17 |
+
import math
|
18 |
+
import warnings
|
19 |
+
from dataclasses import dataclass
|
20 |
+
from typing import Any, Optional, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
27 |
+
from torch.nn.init import _calculate_fan_in_and_fan_out
|
28 |
+
import torch.nn.functional as F
|
29 |
+
|
30 |
+
from transformers.activations import ACT2FN
|
31 |
+
from transformers.modeling_utils import PreTrainedModel
|
32 |
+
from transformers.utils import (add_start_docstrings,
|
33 |
+
add_start_docstrings_to_model_forward,
|
34 |
+
is_flash_attn_2_available,
|
35 |
+
is_flash_attn_greater_or_equal_2_10, logging,
|
36 |
+
replace_return_docstrings)
|
37 |
+
from .configuration_damovl_encoder import DAMOVLVisionConfig
|
38 |
+
|
39 |
+
|
40 |
+
if is_flash_attn_2_available():
|
41 |
+
from flash_attn import flash_attn_varlen_func
|
42 |
+
from transformers.modeling_flash_attention_utils import \
|
43 |
+
_flash_attention_forward
|
44 |
+
else:
|
45 |
+
flash_attn_varlen_func = None
|
46 |
+
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
def _trunc_normal_(tensor, mean, std, a, b):
|
52 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
53 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
54 |
+
def norm_cdf(x):
|
55 |
+
# Computes standard normal cumulative distribution function
|
56 |
+
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
57 |
+
|
58 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
59 |
+
warnings.warn(
|
60 |
+
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
61 |
+
"The distribution of values may be incorrect.",
|
62 |
+
stacklevel=2,
|
63 |
+
)
|
64 |
+
|
65 |
+
# Values are generated by using a truncated uniform distribution and
|
66 |
+
# then using the inverse CDF for the normal distribution.
|
67 |
+
# Get upper and lower cdf values
|
68 |
+
l = norm_cdf((a - mean) / std)
|
69 |
+
u = norm_cdf((b - mean) / std)
|
70 |
+
|
71 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
72 |
+
# [2l-1, 2u-1].
|
73 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
74 |
+
|
75 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
76 |
+
# standard normal
|
77 |
+
tensor.erfinv_()
|
78 |
+
|
79 |
+
# Transform to proper mean, std
|
80 |
+
tensor.mul_(std * math.sqrt(2.0))
|
81 |
+
tensor.add_(mean)
|
82 |
+
|
83 |
+
# Clamp to ensure it's in the proper range
|
84 |
+
tensor.clamp_(min=a, max=b)
|
85 |
+
|
86 |
+
|
87 |
+
def trunc_normal_tf_(
|
88 |
+
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
|
89 |
+
) -> torch.Tensor:
|
90 |
+
"""Fills the input Tensor with values drawn from a truncated
|
91 |
+
normal distribution. The values are effectively drawn from the
|
92 |
+
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
93 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
94 |
+
the bounds. The method used for generating the random values works
|
95 |
+
best when :math:`a \\leq \text{mean} \\leq b`.
|
96 |
+
|
97 |
+
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
98 |
+
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
99 |
+
and the result is subsequently scaled and shifted by the mean and std args.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
tensor: an n-dimensional `torch.Tensor`
|
103 |
+
mean: the mean of the normal distribution
|
104 |
+
std: the standard deviation of the normal distribution
|
105 |
+
a: the minimum cutoff value
|
106 |
+
b: the maximum cutoff value
|
107 |
+
"""
|
108 |
+
with torch.no_grad():
|
109 |
+
_trunc_normal_(tensor, 0, 1.0, a, b)
|
110 |
+
tensor.mul_(std).add_(mean)
|
111 |
+
|
112 |
+
|
113 |
+
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
114 |
+
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
115 |
+
if mode == "fan_in":
|
116 |
+
denom = fan_in
|
117 |
+
elif mode == "fan_out":
|
118 |
+
denom = fan_out
|
119 |
+
elif mode == "fan_avg":
|
120 |
+
denom = (fan_in + fan_out) / 2
|
121 |
+
|
122 |
+
variance = scale / denom
|
123 |
+
|
124 |
+
if distribution == "truncated_normal":
|
125 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
126 |
+
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
127 |
+
elif distribution == "normal":
|
128 |
+
with torch.no_grad():
|
129 |
+
tensor.normal_(std=math.sqrt(variance))
|
130 |
+
elif distribution == "uniform":
|
131 |
+
bound = math.sqrt(3 * variance)
|
132 |
+
with torch.no_grad():
|
133 |
+
tensor.uniform_(-bound, bound)
|
134 |
+
else:
|
135 |
+
raise ValueError(f"invalid distribution {distribution}")
|
136 |
+
|
137 |
+
|
138 |
+
def lecun_normal_(tensor):
|
139 |
+
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
140 |
+
|
141 |
+
|
142 |
+
def default_flax_embed_init(tensor):
|
143 |
+
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
144 |
+
|
145 |
+
|
146 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
147 |
+
def rotate_half(x):
|
148 |
+
"""Rotates half the hidden dims of the input."""
|
149 |
+
x1 = x[..., : x.shape[-1] // 2]
|
150 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
151 |
+
return torch.cat((-x2, x1), dim=-1)
|
152 |
+
|
153 |
+
|
154 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
155 |
+
orig_dtype = tensor.dtype
|
156 |
+
tensor = tensor.float()
|
157 |
+
cos = freqs.cos()
|
158 |
+
sin = freqs.sin()
|
159 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
160 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
161 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
162 |
+
output = output.to(orig_dtype)
|
163 |
+
return output
|
164 |
+
|
165 |
+
|
166 |
+
class VisionRotaryEmbedding(nn.Module):
|
167 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
168 |
+
super().__init__()
|
169 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
170 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
171 |
+
|
172 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
173 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
174 |
+
freqs = torch.outer(seq, self.inv_freq)
|
175 |
+
return freqs
|
176 |
+
|
177 |
+
|
178 |
+
class DAMOVLVisionEmbeddings(nn.Module):
|
179 |
+
def __init__(self, config: DAMOVLVisionConfig):
|
180 |
+
super().__init__()
|
181 |
+
self.config = config
|
182 |
+
self.embed_dim = config.hidden_size
|
183 |
+
self.patch_size = config.patch_size
|
184 |
+
|
185 |
+
self.patch_embedding = nn.Conv2d(
|
186 |
+
in_channels=config.num_channels,
|
187 |
+
out_channels=self.embed_dim,
|
188 |
+
kernel_size=self.patch_size,
|
189 |
+
stride=self.patch_size,
|
190 |
+
padding="valid",
|
191 |
+
)
|
192 |
+
|
193 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
194 |
+
hidden_states = hidden_states.view(
|
195 |
+
-1, self.config.num_channels, self.patch_size, self.patch_size
|
196 |
+
)
|
197 |
+
patch_embeds = self.patch_embedding(hidden_states) # shape = [*, width, grid, grid]
|
198 |
+
# embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
199 |
+
embeddings = patch_embeds.view(-1, self.embed_dim)
|
200 |
+
|
201 |
+
return embeddings
|
202 |
+
|
203 |
+
|
204 |
+
class VisionAttention(nn.Module):
|
205 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
206 |
+
|
207 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
208 |
+
def __init__(self, config):
|
209 |
+
super().__init__()
|
210 |
+
self.config = config
|
211 |
+
self.embed_dim = config.hidden_size
|
212 |
+
self.num_heads = config.num_attention_heads
|
213 |
+
self.head_dim = self.embed_dim // self.num_heads
|
214 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
215 |
+
raise ValueError(
|
216 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
217 |
+
f" {self.num_heads})."
|
218 |
+
)
|
219 |
+
self.scale = self.head_dim**-0.5
|
220 |
+
self.dropout = config.attention_dropout
|
221 |
+
|
222 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
223 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
224 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
225 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
226 |
+
|
227 |
+
def forward(
|
228 |
+
self,
|
229 |
+
hidden_states: torch.Tensor,
|
230 |
+
cu_seqlens: torch.Tensor,
|
231 |
+
rotary_pos_emb: torch.Tensor = None,
|
232 |
+
) -> torch.Tensor:
|
233 |
+
"""Input shape: Time x Channel"""
|
234 |
+
|
235 |
+
q_len, _ = hidden_states.size()
|
236 |
+
|
237 |
+
query_states = self.q_proj(hidden_states)
|
238 |
+
key_states = self.k_proj(hidden_states)
|
239 |
+
value_states = self.v_proj(hidden_states)
|
240 |
+
|
241 |
+
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
|
242 |
+
key_states = key_states.view(q_len, self.num_heads, self.head_dim)
|
243 |
+
value_states = value_states.view(q_len, self.num_heads, self.head_dim)
|
244 |
+
|
245 |
+
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
246 |
+
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
247 |
+
|
248 |
+
attention_mask = torch.zeros([1, q_len, q_len], device=q.device, dtype=torch.bool)
|
249 |
+
for i in range(1, len(cu_seqlens)):
|
250 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
251 |
+
|
252 |
+
query_states = query_states.transpose(0, 1)
|
253 |
+
key_states = key_states.transpose(0, 1)
|
254 |
+
value_states = value_states.transpose(0, 1)
|
255 |
+
|
256 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
|
257 |
+
attn_weights = attn_weights + attention_mask
|
258 |
+
|
259 |
+
# upcast attention to fp32
|
260 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
261 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
262 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
263 |
+
|
264 |
+
attn_output = attn_output.transpose(0, 1)
|
265 |
+
attn_output = attn_output.reshape(q_len, -1)
|
266 |
+
attn_output = self.out_proj(attn_output)
|
267 |
+
|
268 |
+
return attn_output
|
269 |
+
|
270 |
+
|
271 |
+
class VisionFlashAttention2(VisionAttention):
|
272 |
+
def __init__(self, *args, **kwargs):
|
273 |
+
super().__init__(*args, **kwargs)
|
274 |
+
|
275 |
+
# Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
|
276 |
+
def forward(
|
277 |
+
self,
|
278 |
+
hidden_states: torch.Tensor,
|
279 |
+
cu_seqlens: torch.Tensor,
|
280 |
+
rotary_pos_emb: torch.Tensor = None,
|
281 |
+
) -> torch.Tensor:
|
282 |
+
q_len, _ = hidden_states.size()
|
283 |
+
|
284 |
+
query_states = self.q_proj(hidden_states)
|
285 |
+
key_states = self.k_proj(hidden_states)
|
286 |
+
value_states = self.v_proj(hidden_states)
|
287 |
+
|
288 |
+
# Flash attention requires the input to have the shape
|
289 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
290 |
+
# therefore we just need to keep the original shape
|
291 |
+
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
|
292 |
+
key_states = key_states.view(q_len, self.num_heads, self.head_dim)
|
293 |
+
value_states = value_states.view(q_len, self.num_heads, self.head_dim)
|
294 |
+
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
295 |
+
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
296 |
+
|
297 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
298 |
+
attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
299 |
+
q_len, -1
|
300 |
+
)
|
301 |
+
attn_output = self.out_proj(attn_output)
|
302 |
+
|
303 |
+
return attn_output
|
304 |
+
|
305 |
+
|
306 |
+
class VisionSdpaAttention(VisionAttention):
|
307 |
+
def forward(
|
308 |
+
self,
|
309 |
+
hidden_states: torch.Tensor,
|
310 |
+
cu_seqlens: torch.Tensor,
|
311 |
+
rotary_pos_emb: torch.Tensor = None,
|
312 |
+
) -> torch.Tensor:
|
313 |
+
if output_attentions:
|
314 |
+
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
|
315 |
+
logger.warning_once(
|
316 |
+
"DAMOVLVisionModel is using VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
|
317 |
+
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
318 |
+
)
|
319 |
+
return super().forward(
|
320 |
+
hidden_states=hidden_states,
|
321 |
+
cu_seqlens=cu_seqlens,
|
322 |
+
rotary_pos_emb=rotary_pos_emb,
|
323 |
+
)
|
324 |
+
|
325 |
+
seq_length = hidden_states.shape[0]
|
326 |
+
query_states = self.q_proj(hidden_states)
|
327 |
+
key_states = self.k_proj(hidden_states)
|
328 |
+
value_states = self.v_proj(hidden_states)
|
329 |
+
|
330 |
+
query_states = query_states.view(q_len, self.num_heads, self.head_dim)
|
331 |
+
key_states = key_states.view(q_len, self.num_heads, self.head_dim)
|
332 |
+
value_states = value_states.view(q_len, self.num_heads, self.head_dim)
|
333 |
+
|
334 |
+
query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
335 |
+
key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
336 |
+
|
337 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
338 |
+
for i in range(1, len(cu_seqlens)):
|
339 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
340 |
+
|
341 |
+
query_states = query_states.transpose(0, 1)
|
342 |
+
key_states = key_states.transpose(0, 1)
|
343 |
+
value_states = value_states.transpose(0, 1)
|
344 |
+
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
|
345 |
+
attn_output = attn_output.transpose(0, 1)
|
346 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
347 |
+
attn_output = self.proj(attn_output)
|
348 |
+
return attn_output
|
349 |
+
|
350 |
+
|
351 |
+
DAMOVL_VISION_ATTENTION_CLASSES = {
|
352 |
+
"eager": VisionAttention,
|
353 |
+
"flash_attention_2": VisionFlashAttention2,
|
354 |
+
"sdpa": VisionSdpaAttention,
|
355 |
+
}
|
356 |
+
|
357 |
+
|
358 |
+
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->DAMOVL
|
359 |
+
class DAMOVLVisionMLP(nn.Module):
|
360 |
+
def __init__(self, config):
|
361 |
+
super().__init__()
|
362 |
+
self.config = config
|
363 |
+
self.activation_fn = ACT2FN[config.hidden_act]
|
364 |
+
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
365 |
+
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
366 |
+
|
367 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
368 |
+
hidden_states = self.fc1(hidden_states)
|
369 |
+
hidden_states = self.activation_fn(hidden_states)
|
370 |
+
hidden_states = self.fc2(hidden_states)
|
371 |
+
return hidden_states
|
372 |
+
|
373 |
+
|
374 |
+
class DAMOVLVisionEncoderLayer(nn.Module):
|
375 |
+
def __init__(self, config: DAMOVLVisionConfig):
|
376 |
+
super().__init__()
|
377 |
+
self.embed_dim = config.hidden_size
|
378 |
+
self.self_attn = DAMOVL_VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
|
379 |
+
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
380 |
+
self.mlp = DAMOVLVisionMLP(config)
|
381 |
+
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
382 |
+
|
383 |
+
# Ignore copy
|
384 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
385 |
+
hidden_states = hidden_states + self.self_attn(
|
386 |
+
self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
387 |
+
)
|
388 |
+
hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
|
389 |
+
return hidden_states
|
390 |
+
|
391 |
+
|
392 |
+
class DAMOVLPreTrainedModel(PreTrainedModel):
|
393 |
+
"""
|
394 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
395 |
+
models.
|
396 |
+
"""
|
397 |
+
|
398 |
+
config_class = DAMOVLVisionConfig
|
399 |
+
base_model_prefix = "damovl"
|
400 |
+
supports_gradient_checkpointing = True
|
401 |
+
_no_split_modules = [
|
402 |
+
"DAMOVLVisionEncoderLayer",
|
403 |
+
"DAMOVLVisionEmbeddings",
|
404 |
+
]
|
405 |
+
_supports_flash_attn_2 = True
|
406 |
+
_supports_sdpa = True
|
407 |
+
|
408 |
+
def _init_weights(self, module):
|
409 |
+
"""Initialize the weights"""
|
410 |
+
if isinstance(module, nn.Embedding):
|
411 |
+
default_flax_embed_init(module.weight)
|
412 |
+
elif isinstance(module, VisionAttention):
|
413 |
+
nn.init.xavier_uniform_(module.q_proj.weight)
|
414 |
+
nn.init.xavier_uniform_(module.k_proj.weight)
|
415 |
+
nn.init.xavier_uniform_(module.v_proj.weight)
|
416 |
+
nn.init.xavier_uniform_(module.out_proj.weight)
|
417 |
+
nn.init.zeros_(module.q_proj.bias)
|
418 |
+
nn.init.zeros_(module.k_proj.bias)
|
419 |
+
nn.init.zeros_(module.v_proj.bias)
|
420 |
+
nn.init.zeros_(module.out_proj.bias)
|
421 |
+
elif isinstance(module, DAMOVLVisionMLP):
|
422 |
+
nn.init.xavier_uniform_(module.fc1.weight)
|
423 |
+
nn.init.xavier_uniform_(module.fc2.weight)
|
424 |
+
nn.init.normal_(module.fc1.bias, std=1e-6)
|
425 |
+
nn.init.normal_(module.fc2.bias, std=1e-6)
|
426 |
+
elif isinstance(module, (nn.Linear, nn.Conv2d)):
|
427 |
+
lecun_normal_(module.weight)
|
428 |
+
if module.bias is not None:
|
429 |
+
nn.init.zeros_(module.bias)
|
430 |
+
elif isinstance(module, nn.LayerNorm):
|
431 |
+
module.bias.data.zero_()
|
432 |
+
module.weight.data.fill_(1.0)
|
433 |
+
|
434 |
+
|
435 |
+
class DAMOVLVisionEncoder(nn.Module):
|
436 |
+
def __init__(self, config: DAMOVLVisionConfig):
|
437 |
+
super().__init__()
|
438 |
+
self.config = config
|
439 |
+
head_dim = config.hidden_size // config.num_attention_heads
|
440 |
+
self.spatial_merge_size = config.spatial_merge_size
|
441 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
442 |
+
self.layers = nn.ModuleList([DAMOVLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
443 |
+
self.gradient_checkpointing = False
|
444 |
+
|
445 |
+
def rot_pos_emb(self, grid_thw, strides):
|
446 |
+
pos_ids = []
|
447 |
+
for (t, h, w), stride in zip(grid_thw, strides):
|
448 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
449 |
+
hpos_ids = hpos_ids.reshape(
|
450 |
+
h // stride,
|
451 |
+
stride,
|
452 |
+
w // stride,
|
453 |
+
stride,
|
454 |
+
)
|
455 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
456 |
+
hpos_ids = hpos_ids.flatten()
|
457 |
+
|
458 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
459 |
+
wpos_ids = wpos_ids.reshape(
|
460 |
+
h // stride,
|
461 |
+
stride,
|
462 |
+
w // stride,
|
463 |
+
stride,
|
464 |
+
)
|
465 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
466 |
+
wpos_ids = wpos_ids.flatten()
|
467 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
468 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
469 |
+
max_grid_size = grid_thw[:, 1:].max()
|
470 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
471 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
472 |
+
return rotary_pos_emb
|
473 |
+
|
474 |
+
def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
|
475 |
+
# BUG: These codes will cause deepspeed issue: `RuntimeError: disagreement between rank0 and rankx`
|
476 |
+
# rotary_pos_emb = []
|
477 |
+
# for thw in grid_thws:
|
478 |
+
# rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
|
479 |
+
# rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
|
480 |
+
# grid_thws = torch.cat(grid_thws, dim = 0)
|
481 |
+
|
482 |
+
# new version of creating rotary position embedding
|
483 |
+
# grid_thws shapes like [batch_flatten_image_num, 3]
|
484 |
+
# grid_thws = torch.cat(grid_thws, dim = 0) # is conducted in the `encoder.py`
|
485 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)
|
486 |
+
|
487 |
+
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
|
488 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
489 |
+
|
490 |
+
for blk in self.layers:
|
491 |
+
if self.gradient_checkpointing and self.training:
|
492 |
+
hidden_states = self._gradient_checkpointing_func(
|
493 |
+
blk.__call__,
|
494 |
+
hidden_states,
|
495 |
+
cu_seqlens,
|
496 |
+
rotary_pos_emb
|
497 |
+
)
|
498 |
+
else:
|
499 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
500 |
+
return hidden_states
|
501 |
+
|
502 |
+
|
503 |
+
class DAMOVLVisionTransformer(nn.Module):
|
504 |
+
def __init__(self, config: DAMOVLVisionConfig):
|
505 |
+
super().__init__()
|
506 |
+
self.config = config
|
507 |
+
embed_dim = config.hidden_size
|
508 |
+
|
509 |
+
self.embeddings = DAMOVLVisionEmbeddings(config)
|
510 |
+
self.encoder = DAMOVLVisionEncoder(config)
|
511 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
512 |
+
|
513 |
+
def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
|
514 |
+
|
515 |
+
# print(hidden_states)
|
516 |
+
|
517 |
+
# hidden_states = torch.cat(hidden_states, dim = 1)
|
518 |
+
|
519 |
+
hidden_states = self.embeddings(hidden_states)
|
520 |
+
hidden_states = self.encoder(hidden_states, grid_thws, strides)
|
521 |
+
hidden_states = self.post_layernorm(hidden_states)
|
522 |
+
|
523 |
+
return hidden_states
|
524 |
+
|
525 |
+
|
526 |
+
class DAMOVLVisionModel(DAMOVLPreTrainedModel):
|
527 |
+
config_class = DAMOVLVisionConfig
|
528 |
+
main_input_name = "hidden_states"
|
529 |
+
|
530 |
+
def __init__(self, config: DAMOVLVisionConfig):
|
531 |
+
super().__init__(config)
|
532 |
+
|
533 |
+
self.vision_model = DAMOVLVisionTransformer(config)
|
534 |
+
|
535 |
+
# Initialize weights and apply final processing
|
536 |
+
self.post_init()
|
537 |
+
|
538 |
+
def get_input_embeddings(self) -> nn.Module:
|
539 |
+
return self.vision_model.embeddings.patch_embedding
|
540 |
+
|
541 |
+
def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
|
542 |
+
return self.vision_model(hidden_states=hidden_states, grid_thws=grid_thws, strides=strides)
|
videollama3/model/encoder.py
ADDED
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from transformers import (CLIPImageProcessor, CLIPVisionConfig,
|
6 |
+
CLIPVisionModel, SiglipImageProcessor,
|
7 |
+
SiglipVisionConfig, SiglipVisionModel)
|
8 |
+
|
9 |
+
from .qwen2vl_encoder import (Qwen2VisionTransformerPretrainedModel,
|
10 |
+
Qwen2VLImageProcessor, Qwen2VLVisionConfig)
|
11 |
+
|
12 |
+
from .damovl_encoder import (DAMOVLImageProcessor, DAMOVLVisionModel)
|
13 |
+
|
14 |
+
|
15 |
+
class CLIPVisionEncoder(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.is_loaded = False
|
21 |
+
|
22 |
+
self.vision_encoder_name = vision_encoder
|
23 |
+
self.select_layer = args.mm_vision_select_layer
|
24 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
25 |
+
|
26 |
+
if not delay_load:
|
27 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
28 |
+
self.load_model()
|
29 |
+
else:
|
30 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
31 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
32 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
|
33 |
+
|
34 |
+
def load_model(self):
|
35 |
+
if self.is_loaded:
|
36 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
37 |
+
return
|
38 |
+
|
39 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)
|
40 |
+
|
41 |
+
self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name,
|
42 |
+
attn_implementation=self.attn_implementation)
|
43 |
+
|
44 |
+
self.is_loaded = True
|
45 |
+
|
46 |
+
def feature_select(self, image_forward_outs):
|
47 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
48 |
+
if self.select_feature == 'patch':
|
49 |
+
image_features = image_features[:, 1:]
|
50 |
+
elif self.select_feature == 'cls_patch':
|
51 |
+
image_features = image_features
|
52 |
+
else:
|
53 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
54 |
+
return image_features
|
55 |
+
|
56 |
+
def forward(self, images, **kwargs):
|
57 |
+
images = torch.cat(images)
|
58 |
+
if type(images) is list:
|
59 |
+
image_features = []
|
60 |
+
for image in images:
|
61 |
+
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
62 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
63 |
+
image_features.append(image_feature)
|
64 |
+
else:
|
65 |
+
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
66 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
67 |
+
|
68 |
+
return image_features
|
69 |
+
|
70 |
+
@property
|
71 |
+
def dummy_feature(self):
|
72 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
73 |
+
|
74 |
+
@property
|
75 |
+
def dtype(self):
|
76 |
+
return self.vision_encoder.dtype
|
77 |
+
|
78 |
+
@property
|
79 |
+
def device(self):
|
80 |
+
return self.vision_encoder.device
|
81 |
+
|
82 |
+
@property
|
83 |
+
def config(self):
|
84 |
+
if self.is_loaded:
|
85 |
+
return self.vision_encoder.config
|
86 |
+
else:
|
87 |
+
return self.cfg_only
|
88 |
+
|
89 |
+
@property
|
90 |
+
def hidden_size(self):
|
91 |
+
return self.config.hidden_size
|
92 |
+
|
93 |
+
@property
|
94 |
+
def num_patches(self):
|
95 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
96 |
+
|
97 |
+
@property
|
98 |
+
def num_patches_per_side(self):
|
99 |
+
return self.config.image_size // self.config.patch_size
|
100 |
+
|
101 |
+
@property
|
102 |
+
def image_size(self):
|
103 |
+
return self.config.image_size
|
104 |
+
|
105 |
+
|
106 |
+
class SiglipVisionEncoder(nn.Module):
|
107 |
+
|
108 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
109 |
+
super().__init__()
|
110 |
+
|
111 |
+
self.is_loaded = False
|
112 |
+
|
113 |
+
self.vision_encoder_name = vision_encoder
|
114 |
+
self.select_layer = args.mm_vision_select_layer
|
115 |
+
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
|
116 |
+
|
117 |
+
if not delay_load:
|
118 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
119 |
+
self.load_model()
|
120 |
+
else:
|
121 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
122 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
123 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)
|
124 |
+
|
125 |
+
def load_model(self):
|
126 |
+
if self.is_loaded:
|
127 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
128 |
+
return
|
129 |
+
|
130 |
+
self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_encoder_name)
|
131 |
+
|
132 |
+
self.vision_encoder = SiglipVisionModel.from_pretrained(self.vision_encoder_name,
|
133 |
+
attn_implementation=self.attn_implementation)
|
134 |
+
|
135 |
+
self.is_loaded = True
|
136 |
+
|
137 |
+
def feature_select(self, image_forward_outs):
|
138 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
139 |
+
if self.select_feature == 'patch':
|
140 |
+
image_features = image_features
|
141 |
+
else:
|
142 |
+
raise ValueError(f'Unexpected select feature: {self.select_feature}')
|
143 |
+
return image_features
|
144 |
+
|
145 |
+
def forward(self, images, **kwargs):
|
146 |
+
images = torch.cat(images)
|
147 |
+
if type(images) is list:
|
148 |
+
image_features = []
|
149 |
+
for image in images:
|
150 |
+
image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
151 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
152 |
+
image_features.append(image_feature)
|
153 |
+
else:
|
154 |
+
image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
155 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
156 |
+
|
157 |
+
return image_features
|
158 |
+
|
159 |
+
@property
|
160 |
+
def dummy_feature(self):
|
161 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
162 |
+
|
163 |
+
@property
|
164 |
+
def dtype(self):
|
165 |
+
return self.vision_encoder.dtype
|
166 |
+
|
167 |
+
@property
|
168 |
+
def device(self):
|
169 |
+
return self.vision_encoder.device
|
170 |
+
|
171 |
+
@property
|
172 |
+
def config(self):
|
173 |
+
if self.is_loaded:
|
174 |
+
return self.vision_encoder.config
|
175 |
+
else:
|
176 |
+
return self.cfg_only
|
177 |
+
|
178 |
+
@property
|
179 |
+
def hidden_size(self):
|
180 |
+
return self.config.hidden_size
|
181 |
+
|
182 |
+
@property
|
183 |
+
def num_patches(self):
|
184 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
185 |
+
|
186 |
+
@property
|
187 |
+
def num_patches_per_side(self):
|
188 |
+
return self.config.image_size // self.config.patch_size
|
189 |
+
|
190 |
+
@property
|
191 |
+
def image_size(self):
|
192 |
+
return self.config.image_size
|
193 |
+
|
194 |
+
|
195 |
+
class Qwen2VLVisionEncoder(nn.Module):
|
196 |
+
|
197 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
198 |
+
super().__init__()
|
199 |
+
|
200 |
+
self.is_loaded = False
|
201 |
+
|
202 |
+
self.vision_encoder_name = vision_encoder
|
203 |
+
self.select_layer = args.mm_vision_select_layer
|
204 |
+
|
205 |
+
if not delay_load:
|
206 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
207 |
+
self.load_model(args)
|
208 |
+
else:
|
209 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
210 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
211 |
+
self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
|
212 |
+
|
213 |
+
def load_model(self, args):
|
214 |
+
if self.is_loaded:
|
215 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
216 |
+
return
|
217 |
+
|
218 |
+
# merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
|
219 |
+
# for stage 3, the merge_size is set to 2 by argments.
|
220 |
+
self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.vision_encoder_name)
|
221 |
+
self.image_processor.merge_size = args.spatial_merge_size
|
222 |
+
# NOTE: The maximum number of vision tokens is 8192 by default.
|
223 |
+
mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
|
224 |
+
self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
|
225 |
+
self.image_processor.size["max_pixels"] = self.image_processor.max_pixels
|
226 |
+
|
227 |
+
# merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
|
228 |
+
self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
|
229 |
+
self.cfg_only.spatial_merge_size = args.spatial_merge_size
|
230 |
+
|
231 |
+
self.vision_encoder = Qwen2VisionTransformerPretrainedModel.from_pretrained(
|
232 |
+
self.vision_encoder_name,
|
233 |
+
config=self.cfg_only,
|
234 |
+
torch_dtype=args.torch_dtype,
|
235 |
+
attn_implementation=self.attn_implementation)
|
236 |
+
|
237 |
+
self.is_loaded = True
|
238 |
+
|
239 |
+
def forward(self, images, grid_thws, strides, **kwargs):
|
240 |
+
images = [image for sub_images in images for image in sub_images]
|
241 |
+
grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
|
242 |
+
strides = [stride for sub_strides in strides for stride in sub_strides]
|
243 |
+
|
244 |
+
images = torch.cat(images, dim=0)
|
245 |
+
grid_thws = torch.cat(grid_thws, dim=0)
|
246 |
+
|
247 |
+
image_features = self.vision_encoder(images, grid_thws, strides=strides)
|
248 |
+
|
249 |
+
return image_features
|
250 |
+
|
251 |
+
@property
|
252 |
+
def dummy_feature(self):
|
253 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
254 |
+
|
255 |
+
@property
|
256 |
+
def dtype(self):
|
257 |
+
return self.vision_encoder.dtype
|
258 |
+
|
259 |
+
@property
|
260 |
+
def device(self):
|
261 |
+
return self.vision_encoder.device
|
262 |
+
|
263 |
+
@property
|
264 |
+
def config(self):
|
265 |
+
if self.is_loaded:
|
266 |
+
return self.vision_encoder.config
|
267 |
+
else:
|
268 |
+
return self.cfg_only
|
269 |
+
|
270 |
+
@property
|
271 |
+
def hidden_size(self):
|
272 |
+
return self.config.hidden_size
|
273 |
+
|
274 |
+
@property
|
275 |
+
def num_patches(self):
|
276 |
+
return -1
|
277 |
+
|
278 |
+
@property
|
279 |
+
def num_patches_per_side(self):
|
280 |
+
return -1
|
281 |
+
|
282 |
+
@property
|
283 |
+
def image_size(self):
|
284 |
+
return 14 * self.vision_encoder.config.spatial_merge_size
|
285 |
+
|
286 |
+
|
287 |
+
class DAMOVLVisionEncoder(nn.Module):
|
288 |
+
|
289 |
+
def __init__(self, vision_encoder, args, delay_load=False):
|
290 |
+
super().__init__()
|
291 |
+
|
292 |
+
self.is_loaded = False
|
293 |
+
|
294 |
+
self.vision_encoder_name = vision_encoder
|
295 |
+
self.args = args
|
296 |
+
|
297 |
+
if not delay_load:
|
298 |
+
self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
|
299 |
+
self.load_model(self.args)
|
300 |
+
else:
|
301 |
+
# uncertain whether flash-attention-2 is supported during inference phase.
|
302 |
+
self.attn_implementation = 'sdpa' # 'eager'
|
303 |
+
self.cfg_only = DAMOVLVisionConfig.from_pretrained(self.vision_encoder_name)
|
304 |
+
|
305 |
+
def load_model(self, args):
|
306 |
+
if self.is_loaded:
|
307 |
+
print('Vision tower is already loaded, `load model` call again, skipping.')
|
308 |
+
return
|
309 |
+
|
310 |
+
# merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
|
311 |
+
# for stage 3, the merge_size is set to 2 by argments.
|
312 |
+
self.image_processor = DAMOVLImageProcessor.from_pretrained(self.vision_encoder_name)
|
313 |
+
self.image_processor.merge_size = args.spatial_merge_size
|
314 |
+
# NOTE: The maximum number of vision tokens is 8192 by default.
|
315 |
+
mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
|
316 |
+
self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
|
317 |
+
self.image_processor.size["max_pixels"] = self.image_processor.max_pixels
|
318 |
+
|
319 |
+
# merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
|
320 |
+
self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
|
321 |
+
self.cfg_only.spatial_merge_size = args.spatial_merge_size
|
322 |
+
|
323 |
+
self.vision_encoder = DAMOVLVisionModel.from_pretrained(
|
324 |
+
self.vision_encoder_name,
|
325 |
+
spatial_merge_size=args.spatial_merge_size,
|
326 |
+
torch_dtype=args.torch_dtype,
|
327 |
+
attn_implementation=self.attn_implementation)
|
328 |
+
|
329 |
+
self.is_loaded = True
|
330 |
+
|
331 |
+
def forward(self, images, grid_thws, strides, **kwargs):
|
332 |
+
images = [image for sub_images in images for image in sub_images]
|
333 |
+
grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
|
334 |
+
strides = [stride for sub_strides in strides for stride in sub_strides]
|
335 |
+
|
336 |
+
images = torch.cat(images, dim=0)
|
337 |
+
grid_thws = torch.cat(grid_thws, dim=0)
|
338 |
+
|
339 |
+
image_features = self.vision_encoder(images, grid_thws, strides)
|
340 |
+
|
341 |
+
return image_features
|
342 |
+
|
343 |
+
@property
|
344 |
+
def dummy_feature(self):
|
345 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
346 |
+
|
347 |
+
@property
|
348 |
+
def dtype(self):
|
349 |
+
return self.vision_encoder.dtype
|
350 |
+
|
351 |
+
@property
|
352 |
+
def device(self):
|
353 |
+
return self.vision_encoder.device
|
354 |
+
|
355 |
+
@property
|
356 |
+
def config(self):
|
357 |
+
if self.is_loaded:
|
358 |
+
return self.vision_encoder.config
|
359 |
+
else:
|
360 |
+
return self.cfg_only
|
361 |
+
|
362 |
+
@property
|
363 |
+
def hidden_size(self):
|
364 |
+
return self.config.hidden_size
|
365 |
+
|
366 |
+
@property
|
367 |
+
def num_patches(self):
|
368 |
+
return -1
|
369 |
+
|
370 |
+
@property
|
371 |
+
def num_patches_per_side(self):
|
372 |
+
return -1
|
373 |
+
|
374 |
+
@property
|
375 |
+
def image_size(self):
|
376 |
+
return 14 * self.vision_encoder.config.spatial_merge_size
|
377 |
+
|
378 |
+
|
379 |
+
def build_vision_encoder(vision_encoder_cfg, **kwargs):
|
380 |
+
|
381 |
+
vision_encoder = getattr(vision_encoder_cfg, 'mm_vision_encoder', getattr(vision_encoder_cfg, 'vision_encoder', None))
|
382 |
+
|
383 |
+
vision_encoder = DAMOVLVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)
|
384 |
+
|
385 |
+
return vision_encoder
|
videollama3/model/processor.py
ADDED
@@ -0,0 +1,366 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""
|
21 |
+
Processor class for VideoLLaMA3.
|
22 |
+
"""
|
23 |
+
import copy
|
24 |
+
import math
|
25 |
+
import warnings
|
26 |
+
from typing import List, Union, Dict, Optional
|
27 |
+
|
28 |
+
import torch
|
29 |
+
from transformers.feature_extraction_utils import BatchFeature
|
30 |
+
from transformers.image_utils import ImageInput, VideoInput
|
31 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
|
32 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
33 |
+
|
34 |
+
import sys
|
35 |
+
sys.path.append(".")
|
36 |
+
from videollama3.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
|
37 |
+
|
38 |
+
|
39 |
+
DEFAULT_CHAT_TEMPLATE = """
|
40 |
+
{%- set identifier = 'im' %}
|
41 |
+
{% for message in messages %}
|
42 |
+
{% if message['role'] == 'stream' %}
|
43 |
+
{% set identifier = 'stream' %}
|
44 |
+
{% else %}
|
45 |
+
{% set identifier = 'im' %}
|
46 |
+
{% endif %}
|
47 |
+
{{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
|
48 |
+
{% if message['content'] is string %}
|
49 |
+
{{- message['content'] + '<|' + identifier + '_end|>\n' -}}
|
50 |
+
{% else %}
|
51 |
+
{% for content in message['content'] %}
|
52 |
+
{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
|
53 |
+
{% if 'time' in content %}
|
54 |
+
{{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
|
55 |
+
{% endif %}
|
56 |
+
"""
|
57 |
+
DEFAULT_CHAT_TEMPLATE += """
|
58 |
+
{{- '%s\n' -}}
|
59 |
+
""" % DEFAULT_IMAGE_TOKEN
|
60 |
+
DEFAULT_CHAT_TEMPLATE += """
|
61 |
+
{% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
|
62 |
+
{% for i in range(content['num_frames']) %}
|
63 |
+
{% if 'time' in content %}
|
64 |
+
{{- 'Time ' + content['time'][i] | round(1) | string + 's:' -}}
|
65 |
+
{% endif %}
|
66 |
+
{% if i < content['num_frames'] - 1 %}
|
67 |
+
"""
|
68 |
+
DEFAULT_CHAT_TEMPLATE += """
|
69 |
+
{{- '%s,' -}}
|
70 |
+
""" % DEFAULT_IMAGE_TOKEN
|
71 |
+
DEFAULT_CHAT_TEMPLATE += """
|
72 |
+
{% else %}
|
73 |
+
"""
|
74 |
+
DEFAULT_CHAT_TEMPLATE += """
|
75 |
+
{{- '%s\n' -}}
|
76 |
+
""" % DEFAULT_IMAGE_TOKEN
|
77 |
+
DEFAULT_CHAT_TEMPLATE += """
|
78 |
+
{% endif %}
|
79 |
+
{% endfor %}
|
80 |
+
{% elif 'text' in content %}
|
81 |
+
{{- content['text'] -}}
|
82 |
+
{% endif %}
|
83 |
+
{% endfor %}
|
84 |
+
{{- '<|' + identifier + '_end|>\n' -}}
|
85 |
+
{% endif %}
|
86 |
+
{% endfor %}
|
87 |
+
{% if add_generation_prompt %}
|
88 |
+
{{- '<|im_start|>assistant\n' -}}
|
89 |
+
{% endif %}
|
90 |
+
"""
|
91 |
+
|
92 |
+
|
93 |
+
class Videollama3ProcessorKwargs(ProcessingKwargs, total=False):
|
94 |
+
_defaults = {
|
95 |
+
"text_kwargs": {
|
96 |
+
"padding": False,
|
97 |
+
},
|
98 |
+
}
|
99 |
+
|
100 |
+
|
101 |
+
class Videollama3Processor(ProcessorMixin):
|
102 |
+
r"""
|
103 |
+
Modified from Qwen2VLProcessor
|
104 |
+
Args:
|
105 |
+
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
106 |
+
The image processor is a required input.
|
107 |
+
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
108 |
+
The tokenizer is a required input.
|
109 |
+
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
110 |
+
in a chat into a tokenizable string.
|
111 |
+
"""
|
112 |
+
|
113 |
+
attributes = ["image_processor", "tokenizer"]
|
114 |
+
valid_kwargs = ["chat_template"]
|
115 |
+
image_processor_class = "Qwen2VLImageProcessor"
|
116 |
+
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
117 |
+
|
118 |
+
def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
|
119 |
+
if chat_template is None:
|
120 |
+
chat_template = DEFAULT_CHAT_TEMPLATE
|
121 |
+
# super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
122 |
+
tokenizer.chat_template = chat_template
|
123 |
+
self.image_processor = image_processor
|
124 |
+
self.tokenizer = tokenizer
|
125 |
+
self.generation_prompt = self._infer_generation_prompt()
|
126 |
+
self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
|
127 |
+
self.generation_prompt_length = len(self.generation_prompt_ids[0])
|
128 |
+
self.image_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
|
129 |
+
self.eos_token_id = self.tokenizer.eos_token_id
|
130 |
+
|
131 |
+
def get_generation_prompt(self):
|
132 |
+
return self.generation_prompt
|
133 |
+
|
134 |
+
def get_generation_prompt_ids(self):
|
135 |
+
return self.generation_prompt_ids
|
136 |
+
|
137 |
+
def _infer_generation_prompt(self):
|
138 |
+
pseudo_message = [{"role": "user", "content": ""}]
|
139 |
+
instruction = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
|
140 |
+
conversation = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
|
141 |
+
return instruction.replace(conversation, "")
|
142 |
+
|
143 |
+
def _process_text_with_label(
|
144 |
+
self,
|
145 |
+
text: List[Dict],
|
146 |
+
image_grid_thw: torch.Tensor = None,
|
147 |
+
image_downsampling: Optional[int] = None,
|
148 |
+
**kwargs,
|
149 |
+
):
|
150 |
+
assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
|
151 |
+
assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."
|
152 |
+
|
153 |
+
input_ids_list = []
|
154 |
+
targets_list = []
|
155 |
+
sample_types_list = []
|
156 |
+
image_idx = 0
|
157 |
+
|
158 |
+
for message_idx, message in enumerate(text):
|
159 |
+
# 1. set chat template and append image tokens
|
160 |
+
prompt = self.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
|
161 |
+
prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
|
162 |
+
prompt = []
|
163 |
+
for chunk_idx in range(len(prompt_chunks) - 1):
|
164 |
+
prompt.append(prompt_chunks[chunk_idx])
|
165 |
+
thw = image_grid_thw[image_idx]
|
166 |
+
prompt.append(DEFAULT_IMAGE_TOKEN * (thw.prod() / image_downsampling**2).long())
|
167 |
+
image_idx += 1
|
168 |
+
prompt.append(prompt_chunks[-1])
|
169 |
+
prompt = "".join(prompt)
|
170 |
+
|
171 |
+
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
|
172 |
+
input_ids_list.append(input_ids)
|
173 |
+
|
174 |
+
targets = torch.full_like(input_ids, IGNORE_INDEX)
|
175 |
+
sample_types = torch.full_like(input_ids, IGNORE_INDEX)
|
176 |
+
if message["role"] == "assistant":
|
177 |
+
targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
|
178 |
+
elif message["role"] == "stream":
|
179 |
+
diff = torch.diff((input_ids == self.image_token_id).float())
|
180 |
+
image_end_indices = torch.nonzero(diff < 0)[:, 0]
|
181 |
+
targets[image_end_indices + 1] = input_ids[image_end_indices + 1]
|
182 |
+
sample_types = targets.clone()
|
183 |
+
sample_types[torch.logical_and(sample_types > 0, sample_types != self.eos_token_id)] = 0
|
184 |
+
targets[-2] = input_ids[-2] # <|im_end|>
|
185 |
+
|
186 |
+
# if message_idx > 0 and text[message_idx - 1]["role"] == "stream":
|
187 |
+
# targets[0] = input_ids[0]
|
188 |
+
# # TODO: consider non-special tokens
|
189 |
+
# sample_types[0] = input_ids[0]
|
190 |
+
|
191 |
+
targets_list.append(targets)
|
192 |
+
sample_types_list.append(sample_types)
|
193 |
+
|
194 |
+
assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."
|
195 |
+
|
196 |
+
targets = torch.cat(targets_list)
|
197 |
+
sample_types = torch.cat(sample_types_list)
|
198 |
+
types, counts = torch.unique(sample_types[sample_types > -1], return_counts=True)
|
199 |
+
|
200 |
+
if len(types) > 0:
|
201 |
+
target_num_samples = counts.amin()
|
202 |
+
|
203 |
+
for type_id, type_count in zip(types, counts):
|
204 |
+
if type_count > target_num_samples:
|
205 |
+
indices = torch.nonzero(sample_types == type_id)[:, 0]
|
206 |
+
random_selector = torch.randperm(indices.size(0))[:-target_num_samples]
|
207 |
+
targets[indices[random_selector]] = IGNORE_INDEX
|
208 |
+
sample_types[indices[random_selector]] = -1
|
209 |
+
|
210 |
+
text_inputs = {
|
211 |
+
"input_ids": torch.cat(input_ids_list),
|
212 |
+
"labels": targets,
|
213 |
+
}
|
214 |
+
|
215 |
+
return text_inputs
|
216 |
+
|
217 |
+
def _process_text_without_label(
|
218 |
+
self,
|
219 |
+
text: Union[List[str], List[Dict]],
|
220 |
+
image_grid_thw: torch.Tensor = None,
|
221 |
+
image_downsampling: Optional[int] = None,
|
222 |
+
**kwargs,
|
223 |
+
):
|
224 |
+
if isinstance(text[0], dict):
|
225 |
+
warnings.warn("Input text is a list of messages. Automatically convert it to a string with 'apply_chat_template' with generation prompt.")
|
226 |
+
text = [self.tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)]
|
227 |
+
|
228 |
+
image_idx = 0
|
229 |
+
for i in range(len(text)):
|
230 |
+
while DEFAULT_IMAGE_TOKEN in text[i]:
|
231 |
+
thw = image_grid_thw[image_idx]
|
232 |
+
text[i] = text[i].replace(DEFAULT_IMAGE_TOKEN, "<placeholder>" * (thw.prod() / image_downsampling**2).long(), 1)
|
233 |
+
image_idx += 1
|
234 |
+
text[i] = text[i].replace("<placeholder>", DEFAULT_IMAGE_TOKEN)
|
235 |
+
assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."
|
236 |
+
|
237 |
+
text_inputs = self.tokenizer(text, **kwargs)
|
238 |
+
return text_inputs
|
239 |
+
|
240 |
+
def _process_text(
|
241 |
+
self,
|
242 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]],
|
243 |
+
image_grid_thw: torch.Tensor = None,
|
244 |
+
image_downsampling: Optional[int] = None,
|
245 |
+
return_labels: bool = False,
|
246 |
+
**kwargs,
|
247 |
+
):
|
248 |
+
if not isinstance(text, (list, tuple)):
|
249 |
+
text = [text]
|
250 |
+
assert len(text), "At least one text must be provided."
|
251 |
+
|
252 |
+
if return_labels:
|
253 |
+
return self._process_text_with_label(text, image_grid_thw, image_downsampling, **kwargs)
|
254 |
+
return self._process_text_without_label(text, image_grid_thw, image_downsampling, **kwargs)
|
255 |
+
|
256 |
+
def _process_image(
|
257 |
+
self,
|
258 |
+
images: ImageInput = None,
|
259 |
+
image_downsampling: Optional[int] = None,
|
260 |
+
**kwargs,
|
261 |
+
):
|
262 |
+
if image_downsampling is None:
|
263 |
+
image_downsampling = self.image_processor.merge_size
|
264 |
+
|
265 |
+
image_inputs = {
|
266 |
+
"images": [],
|
267 |
+
"grid_thws": [],
|
268 |
+
"image_downsampling": image_downsampling
|
269 |
+
}
|
270 |
+
if images is not None and len(images) > 0:
|
271 |
+
num_images = kwargs.get('num_images', len(images))
|
272 |
+
if 'num_images' in kwargs:
|
273 |
+
kwargs.pop('num_images')
|
274 |
+
for image in images:
|
275 |
+
outputs = self.image_processor(images=image, num_images=num_images, image_downsampling=image_downsampling, **kwargs)
|
276 |
+
# images shapes like: [tensor([patches, 1176]), ...]
|
277 |
+
# grid_thws shapes like: tensor([num_images, 3])
|
278 |
+
|
279 |
+
# flatten_patches1 = outputs["pixel_values"].reshape(26, 46, 3, -1)
|
280 |
+
# from matplotlib import pyplot as plt
|
281 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
282 |
+
# plt.savefig('9.png')
|
283 |
+
|
284 |
+
image_inputs["images"].append(outputs["pixel_values"]) #正常的
|
285 |
+
|
286 |
+
# flatten_patches1 = image_inputs["images"][0].reshape(26, 46, 3, -1)
|
287 |
+
# from matplotlib import pyplot as plt
|
288 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
289 |
+
# plt.savefig('12.png')
|
290 |
+
image_inputs["grid_thws"].append(outputs["image_grid_thw"])
|
291 |
+
|
292 |
+
return image_inputs
|
293 |
+
|
294 |
+
|
295 |
+
|
296 |
+
def __call__(
|
297 |
+
self,
|
298 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]] = None,
|
299 |
+
images: ImageInput = None,
|
300 |
+
image_downsampling: Optional[int] = None,
|
301 |
+
return_labels: bool = False,
|
302 |
+
**kwargs: Unpack[Videollama3ProcessorKwargs],
|
303 |
+
) -> BatchFeature:
|
304 |
+
"""
|
305 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
306 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
307 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
308 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
312 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
313 |
+
tensor. Both channels-first and channels-last formats are supported.
|
314 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
315 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
316 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
317 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
318 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
319 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
320 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
321 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
322 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
323 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
327 |
+
|
328 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
329 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
330 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
331 |
+
`None`).
|
332 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
333 |
+
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
334 |
+
"""
|
335 |
+
output_kwargs = self._merge_kwargs(
|
336 |
+
Videollama3ProcessorKwargs,
|
337 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
338 |
+
**kwargs,
|
339 |
+
)
|
340 |
+
output_kwargs["text_kwargs"].pop("padding")
|
341 |
+
output_kwargs["text_kwargs"].pop("padding_side")
|
342 |
+
|
343 |
+
image_inputs = self._process_image(images, image_downsampling, **output_kwargs["images_kwargs"])
|
344 |
+
text_inputs = self._process_text(text, image_inputs["grid_thws"], image_downsampling, return_labels, **output_kwargs["text_kwargs"])
|
345 |
+
|
346 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
347 |
+
|
348 |
+
def batch_decode(self, *args, **kwargs):
|
349 |
+
"""
|
350 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
351 |
+
refer to the docstring of this method for more information.
|
352 |
+
"""
|
353 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
354 |
+
|
355 |
+
def decode(self, *args, **kwargs):
|
356 |
+
"""
|
357 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
358 |
+
the docstring of this method for more information.
|
359 |
+
"""
|
360 |
+
return self.tokenizer.decode(*args, **kwargs)
|
361 |
+
|
362 |
+
@property
|
363 |
+
def model_input_names(self):
|
364 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
365 |
+
image_processor_input_names = self.image_processor.model_input_names
|
366 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
videollama3/model/projector.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Alibaba DAMO Academy
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
|
19 |
+
import einops
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from timm.models.layers import LayerNorm, LayerNorm2d
|
24 |
+
from timm.models.regnet import RegStage
|
25 |
+
from transformers import TRANSFORMERS_CACHE
|
26 |
+
|
27 |
+
|
28 |
+
def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
|
29 |
+
revision = "main"
|
30 |
+
# 1. parse the downloaded cache folder
|
31 |
+
if cache_dir is None:
|
32 |
+
cache_dir = TRANSFORMERS_CACHE
|
33 |
+
else:
|
34 |
+
cache_dir = cache_dir
|
35 |
+
object_id = repo_id.replace("/", "--")
|
36 |
+
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
|
37 |
+
# 2. resolve refs (for instance to convert main to the associated commit sha)
|
38 |
+
refs_dir = os.path.join(repo_cache, "refs")
|
39 |
+
if os.path.isdir(refs_dir):
|
40 |
+
revision_file = os.path.join(refs_dir, revision)
|
41 |
+
if os.path.isfile(revision_file):
|
42 |
+
with open(revision_file) as f:
|
43 |
+
revision = f.read()
|
44 |
+
# 3. acquire the snapshot folder
|
45 |
+
folder = os.path.join(repo_cache, "snapshots", revision)
|
46 |
+
|
47 |
+
return folder
|
48 |
+
|
49 |
+
|
50 |
+
def load_mm_projector(model_path, cache_dir=None, token=None):
|
51 |
+
if os.path.exists(os.path.join(model_path, 'mm_projector.bin')):
|
52 |
+
is_local = True
|
53 |
+
folder = model_path
|
54 |
+
else:
|
55 |
+
is_local = False
|
56 |
+
folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
|
57 |
+
if not os.path.exists(os.path.join(folder, 'mm_projector.bin')):
|
58 |
+
# downloading from remote repo
|
59 |
+
from huggingface_hub import snapshot_download
|
60 |
+
snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
|
61 |
+
|
62 |
+
mm_projector_weights = torch.load(os.path.join(folder, 'mm_projector.bin'), map_location='cpu')
|
63 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
64 |
+
return mm_projector_weights
|
65 |
+
|
66 |
+
|
67 |
+
class IdentityMap(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
def forward(self, x, *args, **kwargs):
|
73 |
+
return x
|
74 |
+
|
75 |
+
@property
|
76 |
+
def config(self):
|
77 |
+
return {"mm_projector_type": 'identity'}
|
78 |
+
|
79 |
+
|
80 |
+
def build_mlp(depth, hidden_size, output_hidden_size):
|
81 |
+
modules = [nn.Linear(hidden_size, output_hidden_size)]
|
82 |
+
for _ in range(1, depth):
|
83 |
+
modules.append(nn.GELU())
|
84 |
+
modules.append(nn.Linear(output_hidden_size, output_hidden_size))
|
85 |
+
return nn.Sequential(*modules)
|
86 |
+
|
87 |
+
|
88 |
+
class SimSpatialConv(nn.Module):
|
89 |
+
|
90 |
+
def __init__(self, config, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
|
91 |
+
super().__init__()
|
92 |
+
self.encoder_hidden_size = encoder_hidden_size = config.mm_hidden_size
|
93 |
+
self.output_hidden_size = output_hidden_size = config.hidden_size
|
94 |
+
self.downsample = downsample
|
95 |
+
self.padding = padding
|
96 |
+
self.sampler = nn.Sequential(
|
97 |
+
nn.Conv2d(
|
98 |
+
in_channels=self.encoder_hidden_size,
|
99 |
+
out_channels=4 * self.encoder_hidden_size,
|
100 |
+
kernel_size=self.downsample,
|
101 |
+
stride=self.downsample,
|
102 |
+
padding=self.padding,
|
103 |
+
bias=True
|
104 |
+
),
|
105 |
+
nn.SiLU(),
|
106 |
+
)
|
107 |
+
self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
hw = int(x.size(1) ** 0.5)
|
111 |
+
x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
|
112 |
+
x = self.sampler(x)
|
113 |
+
x = einops.rearrange(x, "b d h w -> b (h w) d")
|
114 |
+
x = self.readout(x)
|
115 |
+
return x
|
116 |
+
|
117 |
+
def cal_proj_size(self, input_size):
|
118 |
+
if isinstance(input_size, int):
|
119 |
+
input_size = (input_size, input_size)
|
120 |
+
height = math.ceil((input_size[0] + self.padding) / self.downsample[0])
|
121 |
+
width = math.ceil((input_size[1] + self.padding) / self.downsample[1])
|
122 |
+
return height * width
|
123 |
+
|
124 |
+
|
125 |
+
class MlpGeluProjector(nn.Module):
|
126 |
+
def __init__(self, config, projector_type):
|
127 |
+
super().__init__()
|
128 |
+
|
129 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
130 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
131 |
+
|
132 |
+
self.readout = build_mlp(mlp_depth, config.mm_hidden_size, config.hidden_size)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
x = self.readout(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
def cal_proj_size(self, input_size):
|
139 |
+
if isinstance(input_size, int):
|
140 |
+
input_size = (input_size, input_size)
|
141 |
+
height = input_size[0]
|
142 |
+
width = input_size[1]
|
143 |
+
return height * width
|
144 |
+
|
145 |
+
|
146 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
147 |
+
# videollama3 projector only support image-wise operation now, i.e., prohibit the temporal aggregation
|
148 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
149 |
+
|
150 |
+
if projector_type == "linear":
|
151 |
+
# NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggreate video features
|
152 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
153 |
+
elif projector_type == "simp_spatial_conv":
|
154 |
+
return SimSpatialConv(config)
|
155 |
+
elif projector_type.startswith("mlp"):
|
156 |
+
return MlpGeluProjector(config, projector_type)
|
157 |
+
if projector_type == 'identity':
|
158 |
+
return IdentityMap()
|
159 |
+
|
160 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
videollama3/model/qwen2vl_encoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
|
2 |
+
from .image_processing import Qwen2VLImageProcessor
|
3 |
+
from .modeling_qwen2vl_encoder import Qwen2VisionTransformerPretrainedModel
|
videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (432 Bytes). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc
ADDED
Binary file (1.92 kB). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc
ADDED
Binary file (16.9 kB). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc
ADDED
Binary file (12.7 kB). View file
|
|
videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Qwen2VL model configuration"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import Union
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class Qwen2VLVisionConfig(PretrainedConfig):
|
28 |
+
model_type = "qwen2_vl"
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
depth=32,
|
33 |
+
embed_dim=1280,
|
34 |
+
hidden_size=3584,
|
35 |
+
hidden_act="quick_gelu",
|
36 |
+
mlp_ratio=4,
|
37 |
+
num_heads=16,
|
38 |
+
in_channels=3,
|
39 |
+
patch_size=14,
|
40 |
+
spatial_merge_size=2,
|
41 |
+
temporal_patch_size=2,
|
42 |
+
**kwargs,
|
43 |
+
):
|
44 |
+
super().__init__(**kwargs)
|
45 |
+
|
46 |
+
self.depth = depth
|
47 |
+
self.embed_dim = embed_dim
|
48 |
+
self.hidden_size = hidden_size
|
49 |
+
self.hidden_act = hidden_act
|
50 |
+
self.mlp_ratio = mlp_ratio
|
51 |
+
self.num_heads = num_heads
|
52 |
+
self.in_channels = in_channels
|
53 |
+
self.patch_size = patch_size
|
54 |
+
self.spatial_merge_size = spatial_merge_size
|
55 |
+
self.temporal_patch_size = temporal_patch_size
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
59 |
+
cls._set_token_in_kwargs(kwargs)
|
60 |
+
|
61 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
62 |
+
|
63 |
+
# if config_dict.get("model_type") == "qwen2_vl":
|
64 |
+
# config_dict = config_dict["vision_config"]
|
65 |
+
|
66 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
67 |
+
logger.warning(
|
68 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
69 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
70 |
+
)
|
71 |
+
|
72 |
+
return cls.from_dict(config_dict, **kwargs)
|
videollama3/model/qwen2vl_encoder/image_processing.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""Image processor class for Qwen2-VL."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from typing import Dict, List, Optional, Union
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
28 |
+
from transformers.image_transforms import (
|
29 |
+
convert_to_rgb,
|
30 |
+
resize,
|
31 |
+
to_channel_dimension_format,
|
32 |
+
)
|
33 |
+
from transformers.image_utils import (
|
34 |
+
OPENAI_CLIP_MEAN,
|
35 |
+
OPENAI_CLIP_STD,
|
36 |
+
ChannelDimension,
|
37 |
+
ImageInput,
|
38 |
+
PILImageResampling,
|
39 |
+
VideoInput,
|
40 |
+
get_image_size,
|
41 |
+
infer_channel_dimension_format,
|
42 |
+
is_scaled_image,
|
43 |
+
is_valid_image,
|
44 |
+
make_list_of_images,
|
45 |
+
to_numpy_array,
|
46 |
+
valid_images,
|
47 |
+
validate_preprocess_arguments,
|
48 |
+
)
|
49 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
50 |
+
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
if is_vision_available():
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
|
59 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
60 |
+
"""
|
61 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
65 |
+
The input image.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
list: A list of images.
|
69 |
+
"""
|
70 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
71 |
+
return [img for img_list in images for img in img_list]
|
72 |
+
|
73 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
74 |
+
return images
|
75 |
+
|
76 |
+
elif is_valid_image(images):
|
77 |
+
return [images]
|
78 |
+
|
79 |
+
raise ValueError(f"Could not make batched images from {images}")
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
83 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
84 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
85 |
+
return videos
|
86 |
+
|
87 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
88 |
+
if isinstance(videos[0], Image.Image):
|
89 |
+
return [videos]
|
90 |
+
elif len(videos[0].shape) == 4:
|
91 |
+
return [list(video) for video in videos]
|
92 |
+
|
93 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
94 |
+
return [list(videos)]
|
95 |
+
|
96 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
97 |
+
|
98 |
+
|
99 |
+
def smart_resize(
|
100 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
101 |
+
):
|
102 |
+
"""Rescales the image so that the following conditions are met:
|
103 |
+
|
104 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
105 |
+
|
106 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
107 |
+
|
108 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
109 |
+
|
110 |
+
"""
|
111 |
+
if height < factor or width < factor:
|
112 |
+
scale = factor / min(height, width)
|
113 |
+
width = round(scale * width)
|
114 |
+
height = round(scale * height)
|
115 |
+
elif max(height, width) / min(height, width) > 200:
|
116 |
+
raise ValueError(
|
117 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
118 |
+
)
|
119 |
+
h_bar = round(height / factor) * factor
|
120 |
+
w_bar = round(width / factor) * factor
|
121 |
+
if h_bar * w_bar > max_pixels:
|
122 |
+
beta = math.sqrt((height * width) / max_pixels)
|
123 |
+
h_bar = math.floor(height / beta / factor) * factor
|
124 |
+
w_bar = math.floor(width / beta / factor) * factor
|
125 |
+
elif h_bar * w_bar < min_pixels:
|
126 |
+
beta = math.sqrt(min_pixels / (height * width))
|
127 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
128 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
129 |
+
return h_bar, w_bar
|
130 |
+
|
131 |
+
|
132 |
+
class Qwen2VLImageProcessor(BaseImageProcessor):
|
133 |
+
r"""
|
134 |
+
Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
138 |
+
Whether to resize the image's (height, width) dimensions.
|
139 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
140 |
+
Resampling filter to use when resizing the image.
|
141 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
143 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
144 |
+
Scale factor to use if rescaling the image.
|
145 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
146 |
+
Whether to normalize the image.
|
147 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
148 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
149 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
150 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
151 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
152 |
+
Whether to convert the image to RGB.
|
153 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
154 |
+
The min pixels of the image to resize the image.
|
155 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
156 |
+
The max pixels of the image to resize the image.
|
157 |
+
patch_size (`int`, *optional*, defaults to 14):
|
158 |
+
The spacial patch size of the vision encoder.
|
159 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
160 |
+
The temporal patch size of the vision encoder.
|
161 |
+
merge_size (`int`, *optional*, defaults to 2):
|
162 |
+
The merge size of the vision encoder to llm encoder.
|
163 |
+
"""
|
164 |
+
|
165 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
do_resize: bool = True,
|
170 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
171 |
+
do_rescale: bool = True,
|
172 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
173 |
+
do_normalize: bool = True,
|
174 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
175 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
176 |
+
do_convert_rgb: bool = True,
|
177 |
+
min_pixels: int = 56 * 56,
|
178 |
+
max_pixels: int = 28 * 28 * 1280,
|
179 |
+
patch_size: int = 14,
|
180 |
+
temporal_patch_size: int = 2,
|
181 |
+
merge_size: int = 2,
|
182 |
+
**kwargs,
|
183 |
+
) -> None:
|
184 |
+
super().__init__(**kwargs)
|
185 |
+
self.do_resize = do_resize
|
186 |
+
self.resample = resample
|
187 |
+
self.do_rescale = do_rescale
|
188 |
+
self.rescale_factor = rescale_factor
|
189 |
+
self.do_normalize = do_normalize
|
190 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
191 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
192 |
+
self.min_pixels = min_pixels
|
193 |
+
self.max_pixels = max_pixels
|
194 |
+
self.patch_size = patch_size
|
195 |
+
self.temporal_patch_size = temporal_patch_size
|
196 |
+
self.merge_size = merge_size
|
197 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
198 |
+
self.do_convert_rgb = do_convert_rgb
|
199 |
+
|
200 |
+
def _preprocess(
|
201 |
+
self,
|
202 |
+
images: Union[ImageInput, VideoInput],
|
203 |
+
do_resize: bool = None,
|
204 |
+
resample: PILImageResampling = None,
|
205 |
+
do_rescale: bool = None,
|
206 |
+
rescale_factor: float = None,
|
207 |
+
do_normalize: bool = None,
|
208 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
209 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
210 |
+
do_convert_rgb: bool = None,
|
211 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
212 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
213 |
+
num_images: Optional[int] = 1,
|
214 |
+
image_downsampling: Optional[int] = None,
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
images (`ImageInput`):
|
221 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
222 |
+
vision_info (`List[Dict]`, *optional*):
|
223 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
224 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
225 |
+
Whether to resize the image.
|
226 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
227 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
228 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
229 |
+
Whether to rescale the image.
|
230 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
231 |
+
Scale factor to use if rescaling the image.
|
232 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
233 |
+
Whether to normalize the image.
|
234 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
235 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
236 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
237 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
238 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
239 |
+
Whether to convert the image to RGB.
|
240 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
241 |
+
The channel dimension format for the output image. Can be one of:
|
242 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
243 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
244 |
+
- Unset: Use the channel dimension format of the input image.
|
245 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
246 |
+
The channel dimension format for the input image. Can be one of:
|
247 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
248 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
249 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
250 |
+
"""
|
251 |
+
images = make_list_of_images(images)
|
252 |
+
|
253 |
+
if do_convert_rgb:
|
254 |
+
images = [convert_to_rgb(image) for image in images]
|
255 |
+
|
256 |
+
# All transformations expect numpy arrays.
|
257 |
+
images = [to_numpy_array(image) for image in images]
|
258 |
+
|
259 |
+
if is_scaled_image(images[0]) and do_rescale:
|
260 |
+
logger.warning_once(
|
261 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
262 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
263 |
+
)
|
264 |
+
if input_data_format is None:
|
265 |
+
# We assume that all images have the same channel dimension format.
|
266 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
267 |
+
|
268 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
269 |
+
resized_height, resized_width = height, width
|
270 |
+
processed_images = []
|
271 |
+
for image in images:
|
272 |
+
if do_resize:
|
273 |
+
max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
|
274 |
+
resized_height, resized_width = smart_resize(
|
275 |
+
height,
|
276 |
+
width,
|
277 |
+
factor=self.patch_size * image_downsampling,
|
278 |
+
min_pixels=self.min_pixels,
|
279 |
+
max_pixels=int(max_pixels // num_images),
|
280 |
+
)
|
281 |
+
image = resize(
|
282 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
283 |
+
)
|
284 |
+
|
285 |
+
if do_rescale:
|
286 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
287 |
+
|
288 |
+
if do_normalize:
|
289 |
+
image = self.normalize(
|
290 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
291 |
+
)
|
292 |
+
|
293 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
294 |
+
processed_images.append(image)
|
295 |
+
|
296 |
+
patches = np.array(processed_images)
|
297 |
+
if data_format == ChannelDimension.LAST:
|
298 |
+
patches = patches.transpose(0, 3, 1, 2)
|
299 |
+
if patches.shape[0] == 1:
|
300 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
301 |
+
channel = patches.shape[1]
|
302 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
303 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
304 |
+
patches = patches.reshape(
|
305 |
+
grid_t,
|
306 |
+
self.temporal_patch_size,
|
307 |
+
channel,
|
308 |
+
grid_h // image_downsampling,
|
309 |
+
image_downsampling,
|
310 |
+
self.patch_size,
|
311 |
+
grid_w // image_downsampling,
|
312 |
+
image_downsampling,
|
313 |
+
self.patch_size,
|
314 |
+
)
|
315 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
316 |
+
flatten_patches = patches.reshape(
|
317 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
318 |
+
)
|
319 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
320 |
+
|
321 |
+
def preprocess(
|
322 |
+
self,
|
323 |
+
images: ImageInput,
|
324 |
+
videos: VideoInput = None,
|
325 |
+
do_resize: bool = None,
|
326 |
+
size: Dict[str, int] = None,
|
327 |
+
resample: PILImageResampling = None,
|
328 |
+
do_rescale: bool = None,
|
329 |
+
rescale_factor: float = None,
|
330 |
+
do_normalize: bool = None,
|
331 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
332 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
333 |
+
do_convert_rgb: bool = None,
|
334 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
335 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
336 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
337 |
+
num_images: Optional[int] = 1,
|
338 |
+
image_downsampling: Optional[int] = None,
|
339 |
+
):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
images (`ImageInput`):
|
343 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
344 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
345 |
+
videos (`VideoInput`):
|
346 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
347 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
348 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
349 |
+
Whether to resize the image.
|
350 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
351 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
352 |
+
the longest edge resized to keep the input aspect ratio.
|
353 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
354 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
355 |
+
has an effect if `do_resize` is set to `True`.
|
356 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
357 |
+
Whether to rescale the image.
|
358 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
359 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
360 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
361 |
+
Whether to normalize the image.
|
362 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
363 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
364 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
365 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
366 |
+
`True`.
|
367 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
368 |
+
Whether to convert the image to RGB.
|
369 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
370 |
+
The type of tensors to return. Can be one of:
|
371 |
+
- Unset: Return a list of `np.ndarray`.
|
372 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
373 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
374 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
375 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
376 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
377 |
+
The channel dimension format for the output image. Can be one of:
|
378 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
379 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
380 |
+
- Unset: Use the channel dimension format of the input image.
|
381 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
382 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
383 |
+
from the input image. Can be one of:
|
384 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
385 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
386 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
387 |
+
|
388 |
+
"""
|
389 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
390 |
+
size = size if size is not None else self.size
|
391 |
+
resample = resample if resample is not None else self.resample
|
392 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
393 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
394 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
395 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
396 |
+
image_std = image_std if image_std is not None else self.image_std
|
397 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
398 |
+
image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size
|
399 |
+
|
400 |
+
if images is not None:
|
401 |
+
images = make_batched_images(images)
|
402 |
+
if videos is not None:
|
403 |
+
videos = make_batched_videos(videos)
|
404 |
+
|
405 |
+
if images is not None and not valid_images(images):
|
406 |
+
raise ValueError(
|
407 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
408 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
409 |
+
)
|
410 |
+
|
411 |
+
validate_preprocess_arguments(
|
412 |
+
rescale_factor=rescale_factor,
|
413 |
+
do_normalize=do_normalize,
|
414 |
+
image_mean=image_mean,
|
415 |
+
image_std=image_std,
|
416 |
+
do_resize=do_resize,
|
417 |
+
size=size,
|
418 |
+
resample=resample,
|
419 |
+
)
|
420 |
+
|
421 |
+
if images is not None:
|
422 |
+
pixel_values, vision_grid_thws = [], []
|
423 |
+
for image in images:
|
424 |
+
patches, image_grid_thw = self._preprocess(
|
425 |
+
image,
|
426 |
+
do_resize=do_resize,
|
427 |
+
resample=resample,
|
428 |
+
do_rescale=do_rescale,
|
429 |
+
rescale_factor=rescale_factor,
|
430 |
+
do_normalize=do_normalize,
|
431 |
+
image_mean=image_mean,
|
432 |
+
image_std=image_std,
|
433 |
+
data_format=data_format,
|
434 |
+
do_convert_rgb=do_convert_rgb,
|
435 |
+
input_data_format=input_data_format,
|
436 |
+
num_images=num_images,
|
437 |
+
image_downsampling=image_downsampling,
|
438 |
+
)
|
439 |
+
pixel_values.extend(patches)
|
440 |
+
vision_grid_thws.append(image_grid_thw)
|
441 |
+
pixel_values = np.array(pixel_values)
|
442 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
443 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
444 |
+
|
445 |
+
if videos is not None:
|
446 |
+
pixel_values, vision_grid_thws = [], []
|
447 |
+
for images in videos:
|
448 |
+
patches, video_grid_thw = self._preprocess(
|
449 |
+
images,
|
450 |
+
do_resize=do_resize,
|
451 |
+
resample=resample,
|
452 |
+
do_rescale=do_rescale,
|
453 |
+
rescale_factor=rescale_factor,
|
454 |
+
do_normalize=do_normalize,
|
455 |
+
image_mean=image_mean,
|
456 |
+
image_std=image_std,
|
457 |
+
data_format=data_format,
|
458 |
+
do_convert_rgb=do_convert_rgb,
|
459 |
+
input_data_format=input_data_format,
|
460 |
+
num_images=num_images,
|
461 |
+
image_downsampling=image_downsampling,
|
462 |
+
)
|
463 |
+
pixel_values.extend(patches)
|
464 |
+
vision_grid_thws.append(video_grid_thw)
|
465 |
+
pixel_values = np.array(pixel_values)
|
466 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
467 |
+
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
468 |
+
|
469 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""PyTorch Qwen2-VL model."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from dataclasses import dataclass
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
31 |
+
from transformers.activations import ACT2FN
|
32 |
+
from transformers.cache_utils import Cache, StaticCache
|
33 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
34 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.utils import (add_start_docstrings,
|
37 |
+
add_start_docstrings_to_model_forward,
|
38 |
+
is_flash_attn_2_available,
|
39 |
+
is_flash_attn_greater_or_equal_2_10, logging,
|
40 |
+
replace_return_docstrings)
|
41 |
+
|
42 |
+
from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
|
43 |
+
|
44 |
+
if is_flash_attn_2_available():
|
45 |
+
from flash_attn import flash_attn_varlen_func
|
46 |
+
from transformers.modeling_flash_attention_utils import \
|
47 |
+
_flash_attention_forward
|
48 |
+
else:
|
49 |
+
flash_attn_varlen_func = None
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__)
|
52 |
+
|
53 |
+
|
54 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
55 |
+
def rotate_half(x):
|
56 |
+
"""Rotates half the hidden dims of the input."""
|
57 |
+
x1 = x[..., : x.shape[-1] // 2]
|
58 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
59 |
+
return torch.cat((-x2, x1), dim=-1)
|
60 |
+
|
61 |
+
|
62 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
63 |
+
orig_dtype = tensor.dtype
|
64 |
+
tensor = tensor.float()
|
65 |
+
cos = freqs.cos()
|
66 |
+
sin = freqs.sin()
|
67 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
68 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
69 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
70 |
+
output = output.to(orig_dtype)
|
71 |
+
return output
|
72 |
+
|
73 |
+
|
74 |
+
class VisionRotaryEmbedding(nn.Module):
|
75 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
76 |
+
super().__init__()
|
77 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
78 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
79 |
+
|
80 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
81 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
82 |
+
freqs = torch.outer(seq, self.inv_freq)
|
83 |
+
return freqs
|
84 |
+
|
85 |
+
|
86 |
+
class PatchEmbed(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
patch_size: int = 14,
|
90 |
+
temporal_patch_size: int = 2,
|
91 |
+
in_channels: int = 3,
|
92 |
+
embed_dim: int = 1152,
|
93 |
+
) -> None:
|
94 |
+
super().__init__()
|
95 |
+
self.patch_size = patch_size
|
96 |
+
self.temporal_patch_size = temporal_patch_size
|
97 |
+
self.in_channels = in_channels
|
98 |
+
self.embed_dim = embed_dim
|
99 |
+
|
100 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
101 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
102 |
+
|
103 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
104 |
+
target_dtype = self.proj.weight.dtype
|
105 |
+
hidden_states = hidden_states.view(
|
106 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
107 |
+
)
|
108 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
109 |
+
return hidden_states
|
110 |
+
|
111 |
+
|
112 |
+
class PatchMerger(nn.Module):
|
113 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
114 |
+
super().__init__()
|
115 |
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
116 |
+
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
117 |
+
self.mlp = nn.Sequential(
|
118 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
119 |
+
nn.GELU(),
|
120 |
+
nn.Linear(self.hidden_size, dim),
|
121 |
+
)
|
122 |
+
|
123 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
124 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
125 |
+
return x
|
126 |
+
|
127 |
+
|
128 |
+
class VisionMlp(nn.Module):
|
129 |
+
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
|
130 |
+
super().__init__()
|
131 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
132 |
+
self.act = ACT2FN[hidden_act]
|
133 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
134 |
+
|
135 |
+
def forward(self, x) -> torch.Tensor:
|
136 |
+
return self.fc2(self.act(self.fc1(x)))
|
137 |
+
|
138 |
+
|
139 |
+
class VisionAttention(nn.Module):
|
140 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
141 |
+
super().__init__()
|
142 |
+
self.num_heads = num_heads
|
143 |
+
self.head_dim = dim // num_heads
|
144 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
145 |
+
self.proj = nn.Linear(dim, dim)
|
146 |
+
|
147 |
+
def forward(
|
148 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
149 |
+
) -> torch.Tensor:
|
150 |
+
seq_length = hidden_states.shape[0]
|
151 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
152 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
153 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
154 |
+
|
155 |
+
attention_mask = torch.full(
|
156 |
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
157 |
+
)
|
158 |
+
for i in range(1, len(cu_seqlens)):
|
159 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
160 |
+
|
161 |
+
q = q.transpose(0, 1)
|
162 |
+
k = k.transpose(0, 1)
|
163 |
+
v = v.transpose(0, 1)
|
164 |
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
165 |
+
attn_weights = attn_weights + attention_mask
|
166 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
167 |
+
attn_output = torch.matmul(attn_weights, v)
|
168 |
+
attn_output = attn_output.transpose(0, 1)
|
169 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
170 |
+
attn_output = self.proj(attn_output)
|
171 |
+
return attn_output
|
172 |
+
|
173 |
+
|
174 |
+
class VisionFlashAttention2(nn.Module):
|
175 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
176 |
+
super().__init__()
|
177 |
+
self.num_heads = num_heads
|
178 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
179 |
+
self.proj = nn.Linear(dim, dim)
|
180 |
+
|
181 |
+
def forward(
|
182 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
183 |
+
) -> torch.Tensor:
|
184 |
+
seq_length = hidden_states.shape[0]
|
185 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
186 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
187 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
188 |
+
|
189 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
190 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
191 |
+
seq_length, -1
|
192 |
+
)
|
193 |
+
attn_output = self.proj(attn_output)
|
194 |
+
return attn_output
|
195 |
+
|
196 |
+
|
197 |
+
class VisionSdpaAttention(nn.Module):
|
198 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
199 |
+
super().__init__()
|
200 |
+
self.num_heads = num_heads
|
201 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
202 |
+
self.proj = nn.Linear(dim, dim)
|
203 |
+
|
204 |
+
def forward(
|
205 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
206 |
+
) -> torch.Tensor:
|
207 |
+
seq_length = hidden_states.shape[0]
|
208 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
209 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
210 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
211 |
+
|
212 |
+
attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
|
213 |
+
for i in range(1, len(cu_seqlens)):
|
214 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
|
215 |
+
q = q.transpose(0, 1)
|
216 |
+
k = k.transpose(0, 1)
|
217 |
+
v = v.transpose(0, 1)
|
218 |
+
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
219 |
+
attn_output = attn_output.transpose(0, 1)
|
220 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
221 |
+
attn_output = self.proj(attn_output)
|
222 |
+
return attn_output
|
223 |
+
|
224 |
+
|
225 |
+
QWEN2_VL_VISION_ATTENTION_CLASSES = {
|
226 |
+
"eager": VisionAttention,
|
227 |
+
"flash_attention_2": VisionFlashAttention2,
|
228 |
+
"sdpa": VisionSdpaAttention,
|
229 |
+
}
|
230 |
+
|
231 |
+
|
232 |
+
class Qwen2VLVisionBlock(nn.Module):
|
233 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
234 |
+
super().__init__()
|
235 |
+
self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
|
236 |
+
self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
|
237 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
238 |
+
|
239 |
+
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
|
240 |
+
config.embed_dim, num_heads=config.num_heads
|
241 |
+
)
|
242 |
+
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
243 |
+
|
244 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
245 |
+
hidden_states = hidden_states + self.attn(
|
246 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
247 |
+
)
|
248 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
249 |
+
return hidden_states
|
250 |
+
|
251 |
+
|
252 |
+
class Qwen2VLPreTrainedModel(PreTrainedModel):
|
253 |
+
config_class = Qwen2VLVisionConfig
|
254 |
+
base_model_prefix = "model"
|
255 |
+
supports_gradient_checkpointing = True
|
256 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
257 |
+
_skip_keys_device_placement = "past_key_values"
|
258 |
+
_supports_flash_attn_2 = True
|
259 |
+
_supports_sdpa = True
|
260 |
+
_supports_cache_class = True
|
261 |
+
_supports_static_cache = True
|
262 |
+
|
263 |
+
def _init_weights(self, module):
|
264 |
+
std = self.config.initializer_range
|
265 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
266 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
267 |
+
if module.bias is not None:
|
268 |
+
module.bias.data.zero_()
|
269 |
+
elif isinstance(module, nn.Embedding):
|
270 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
271 |
+
if module.padding_idx is not None:
|
272 |
+
module.weight.data[module.padding_idx].zero_()
|
273 |
+
|
274 |
+
|
275 |
+
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
276 |
+
config_class = Qwen2VLVisionConfig
|
277 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
278 |
+
|
279 |
+
def __init__(self, config) -> None:
|
280 |
+
super().__init__(config)
|
281 |
+
self.spatial_merge_size = config.spatial_merge_size
|
282 |
+
self.gradient_checkpointing = False
|
283 |
+
|
284 |
+
self.patch_embed = PatchEmbed(
|
285 |
+
patch_size=config.patch_size,
|
286 |
+
temporal_patch_size=config.temporal_patch_size,
|
287 |
+
in_channels=config.in_channels,
|
288 |
+
embed_dim=config.embed_dim,
|
289 |
+
)
|
290 |
+
|
291 |
+
head_dim = config.embed_dim // config.num_heads
|
292 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
293 |
+
|
294 |
+
self.blocks = nn.ModuleList(
|
295 |
+
[Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
|
296 |
+
)
|
297 |
+
#
|
298 |
+
# if self.spatial_merge_size > 1:
|
299 |
+
# self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)
|
300 |
+
|
301 |
+
def get_dtype(self) -> torch.dtype:
|
302 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
303 |
+
|
304 |
+
def get_device(self) -> torch.device:
|
305 |
+
return self.blocks[0].mlp.fc2.weight.device
|
306 |
+
|
307 |
+
def rot_pos_emb(self, grid_thw, strides):
|
308 |
+
pos_ids = []
|
309 |
+
for (t, h, w), stride in zip(grid_thw, strides):
|
310 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
311 |
+
hpos_ids = hpos_ids.reshape(
|
312 |
+
h // stride,
|
313 |
+
stride,
|
314 |
+
w // stride,
|
315 |
+
stride,
|
316 |
+
)
|
317 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
318 |
+
hpos_ids = hpos_ids.flatten()
|
319 |
+
|
320 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
321 |
+
wpos_ids = wpos_ids.reshape(
|
322 |
+
h // stride,
|
323 |
+
stride,
|
324 |
+
w // stride,
|
325 |
+
stride,
|
326 |
+
)
|
327 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
328 |
+
wpos_ids = wpos_ids.flatten()
|
329 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
330 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
331 |
+
max_grid_size = grid_thw[:, 1:].max()
|
332 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
333 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
334 |
+
return rotary_pos_emb
|
335 |
+
|
336 |
+
def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
|
337 |
+
hidden_states = self.patch_embed(hidden_states)
|
338 |
+
|
339 |
+
# BUG: These codes will cause deepspeed issue: `RuntimeError: disagreement between rank0 and rankx`
|
340 |
+
# rotary_pos_emb = []
|
341 |
+
# for thw in grid_thws:
|
342 |
+
# rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
|
343 |
+
# rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
|
344 |
+
# grid_thws = torch.cat(grid_thws, dim = 0)
|
345 |
+
|
346 |
+
# new version of creating rotary position embedding
|
347 |
+
# grid_thws shapes like [batch_flatten_image_num, 3]
|
348 |
+
# grid_thws = torch.cat(grid_thws, dim = 0) # is conducted in the `encoder.py`
|
349 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)
|
350 |
+
|
351 |
+
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
|
352 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
353 |
+
|
354 |
+
for blk in self.blocks:
|
355 |
+
if self.gradient_checkpointing and self.training:
|
356 |
+
hidden_states = self._gradient_checkpointing_func(
|
357 |
+
blk.__call__,
|
358 |
+
hidden_states,
|
359 |
+
cu_seqlens,
|
360 |
+
rotary_pos_emb
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
364 |
+
|
365 |
+
# if self.spatial_merge_size > 1:
|
366 |
+
# hidden_states = self.merger(hidden_states)
|
367 |
+
return hidden_states
|