init
- .DS_Store +0 -0
- app.py +562 -0
- requirements.txt +48 -0
- videollama3/.DS_Store +0 -0
- videollama3/__init__.py +239 -0
- videollama3/constants.py +46 -0
- videollama3/infer.py +82 -0
- videollama3/mm_utils.py +704 -0
- videollama3/model/__init__.py +166 -0
- videollama3/model/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/encoder.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/processor.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/projector.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/region_encoder.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc +0 -0
- videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__init__.py +3 -0
- videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc +0 -0
- videollama3/model/damovl_encoder/configuration_damovl_encoder.py +71 -0
- videollama3/model/damovl_encoder/image_processing.py +472 -0
- videollama3/model/damovl_encoder/modeling_damovl_encoder.py +542 -0
- videollama3/model/encoder.py +385 -0
- videollama3/model/processor.py +366 -0
- videollama3/model/projector.py +160 -0
- videollama3/model/qwen2vl_encoder/__init__.py +3 -0
- videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc +0 -0
- videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py +72 -0
- videollama3/model/qwen2vl_encoder/image_processing.py +469 -0
- videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py +367 -0
- videollama3/model/region_encoder.py +117 -0
- videollama3/model/videollama3_arch.py +422 -0
- videollama3/model/videollama3_qwen2.py +163 -0
- videollama3/train.py +798 -0
- videollama3/videollama3_trainer.py +398 -0
.DS_Store
ADDED
Binary file (6.15 kB).
app.py
ADDED
@@ -0,0 +1,562 @@
import gradio as gr
import numpy as np
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import os
import cv2
import argparse
import sys
# This is for making model initialization faster and has no effect since we are loading the weights
sys.path.append('./')
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_images
from videollama3.mm_utils import load_video


color_rgb = (1.0, 1.0, 1.0)
color_rgbs = [
    (1.0, 1.0, 1.0),
    (1.0, 0.0, 0.0),
    (0.0, 1.0, 1.0),
    (0.0, 1.0, 0.0),
    (0.0, 0.0, 1.0),
    (1.0, 0.0, 1.0),
]

mask_list = []
mask_raw_list = []
mask_list_video = []
mask_raw_list_video = []

def extract_first_frame_from_video(video):
    cap = cv2.VideoCapture(video)
    success, frame = cap.read()
    cap.release()
    if success:
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return None

def extract_points_from_mask(mask_pil):
    mask = np.asarray(mask_pil)[..., 0]
    coords = np.nonzero(mask)
    coords = np.stack((coords[1], coords[0]), axis=1)

    return coords

def add_contour(img, mask, color=(1., 1., 1.)):
    img = img.copy()

    mask = mask.astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(img, contours, -1, color, thickness=8)

    return img

def generate_masks(image):
    global mask_list
    global mask_raw_list
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")

    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)

    mask_raw_list.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))

    mask_list.append((mask_image, f"<region{len(mask_list)}>"))
    # Return a list containing the mask image.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list, image


def generate_masks_video(image):
    global mask_list_video
    global mask_raw_list_video
    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')
    points = extract_points_from_mask(mask)
    np.random.seed(0)
    if points.shape[0] == 0:
        raise gr.Error("No points selected")

    points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
    points = points[points_selected_indices]
    coords = [points.tolist()]
    mask_np = apply_sam(image['image'], coords)

    mask_raw_list_video.append(mask_np)
    mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(image['image'])).astype(np.uint8))

    mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
    # Return a list containing the mask image.
    image['layers'] = []
    image['composite'] = image['background']
    return mask_list_video, image


def describe(image, mode, query, masks):
    # Create an image object from the uploaded image
    # print(image.keys())

    image['image'] = image['background'].convert('RGB')
    # del image['background'], image['composite']
    assert len(image['layers']) == 1, f"Expected 1 layer, got {len(image['layers'])}"

    # Handle both hex and rgba color formats

    img_np = np.asarray(image['image']).astype(float) / 255.
    if mode == 'Caption':
        mask = Image.fromarray((np.asarray(image['layers'][0])[..., 3] > 0).astype(np.uint8) * 255).convert('RGB')

        points = extract_points_from_mask(mask)

        np.random.seed(0)

        if points.shape[0] == 0:
            if len(masks) > 1:
                raise gr.Error("No points selected")

        else:
            # Randomly sample 8 points from the mask
            # Follow DAM https://github.com/NVlabs/describe-anything
            points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
            points = points[points_selected_indices]

        coords = [points.tolist()]

        mask_np = apply_sam(image['image'], coords)

        masks = []
        masks.append(mask_np)
        mask_ids = [0]

        img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
    else:
        masks = mask_raw_list
        img_with_contour_np = img_np.copy()

        mask_ids = []
        for i, mask_np in enumerate(masks):
            img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    img = np.asarray(image['image'])

    if mode == "Caption":
        query = '<image>\nPlease describe the <region> in the image in detail.'
    else:
        if len(masks) == 1:
            prefix = "<image>\nThere is 1 region in the image: <region0> <region>. "
        else:
            prefix = f"<image>\nThere is {len(masks)} region in the image: "
            for i in range(len(masks)):
                prefix += f"<region{i}><region>, "
            prefix = prefix[:-2] + '. '
        query = prefix + query
        # print(query)

    image['layers'] = []
    image['composite'] = image['background']

    text = ""
    yield img_with_contour_pil, text, image

    for token in get_model_output(
        [img],
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()


def load_first_frame(video_path):
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise gr.Error("Could not read the video file.")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    return image

def describe_video(video_path, mode, query, annotated_frame, masks):
    global mask_list_video
    # Create a temporary directory to save extracted video frames
    cap = cv2.VideoCapture(video_path)

    video_tensor = load_video(video_path, fps=4, max_frames=768, frame_ids=[0])

    annotated_frame['image'] = annotated_frame['background'].convert('RGB')

    # Process the annotated frame from the image editor
    if isinstance(annotated_frame, dict):
        # Get the composite image with annotations
        frame_img = annotated_frame.get("image", annotated_frame.get("background"))
        if frame_img is None:
            raise gr.Error("No valid annotation found in the image editor.")
        frame_img = frame_img.convert("RGB")

        # Get the annotation layer
        if "layers" in annotated_frame and len(annotated_frame["layers"]) > 0:
            mask = Image.fromarray((np.asarray(annotated_frame["layers"][0])[..., 3] > 0).astype(np.uint8) * 255).convert("RGB")
        else:
            mask = Image.new("RGB", frame_img.size, 0)
    else:
        frame_img = annotated_frame.convert("RGB")
        mask = Image.new("RGB", frame_img.size, 0)

    img_np = np.asarray(annotated_frame['image']).astype(float) / 255.
    # Extract points from the annotated mask (using the first channel)
    if mode == "Caption":
        points = extract_points_from_mask(mask)
        np.random.seed(0)
        if points.shape[0] == 0:
            raise gr.Error("No points were selected in the annotation.")
        # Randomly select up to 8 points
        # Follow DAM https://github.com/NVlabs/describe-anything
        points_selected_indices = np.random.choice(points.shape[0], size=min(points.shape[0], 8), replace=False)
        points = points[points_selected_indices]

        # print(f"Selected points (to SAM): {points}")

        coords = [points.tolist()]

        mask_np = apply_sam(annotated_frame['image'], coords)

        masks = []
        masks.append(mask_np)
        mask_ids = [0]

        # img_with_contour_np = add_contour(img_np, mask_np, color=color_rgb)
        # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))

    else:
        masks = mask_raw_list_video
        img_with_contour_np = img_np.copy()

        mask_ids = []
        for i, mask_np in enumerate(masks):
            # img_with_contour_np = add_contour(img_with_contour_np, mask_np, color=color_rgbs[i])
            # img_with_contour_pil = Image.fromarray((img_with_contour_np * 255.).astype(np.uint8))
            mask_ids.append(0)

    masks = np.stack(masks, axis=0)
    masks = torch.from_numpy(masks).to(torch.uint8)

    if mode == "Caption":
        query = '<video>\nPlease describe the <region> in the video in detail.'
    else:
        if len(masks) == 1:
            prefix = "<video>\nThere is 1 object in the video: <object0> <region>. "
        else:
            prefix = f"<video>\nThere is {len(masks)} objects in the video: "
            for i in range(len(masks)):
                prefix += f"<object{i}><region>, "
            prefix = prefix[:-2] + '. '
        query = prefix + query

    # Initialize empty text
    # text = description_generator
    annotated_frame['layers'] = []
    annotated_frame['composite'] = annotated_frame['background']

    if mode == "Caption":
        mask_list_video = []
        mask_image = Image.fromarray((mask_np[:, :, np.newaxis] * np.array(annotated_frame['image'])).astype(np.uint8))
        mask_list_video.append((mask_image, f"<object{len(mask_list_video)}>"))
    text = ""
    yield frame_img, text, mask_list_video

    for token in get_model_output(
        video_tensor,
        query,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='video',
        streaming=True,
    ):
        text += token
        yield gr.update(), text, gr.update()


def apply_sam(image, input_points):
    inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())[0][0]
    scores = outputs.iou_scores[0, 0]

    mask_selection_index = scores.argmax()

    mask_np = masks[mask_selection_index].numpy()

    return mask_np

def clear_masks():
    global mask_list
    global mask_raw_list
    mask_list = []
    mask_raw_list = []
    return []


def clear_masks_video():
    global mask_list_video
    global mask_raw_list_video
    mask_list_video = []
    mask_raw_list_video = []
    return []


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="VideoRefer gradio demo")
    parser.add_argument("--model-path", type=str, default="DAMO-NLP-SG/VideoRefer-VideoLLaMA3-7B", help="Path to the model checkpoint")
    parser.add_argument("--prompt-mode", type=str, default="focal_prompt", help="Prompt mode")
    parser.add_argument("--conv-mode", type=str, default="v1", help="Conversation mode")
    parser.add_argument("--temperature", type=float, default=0.2, help="Sampling temperature")
    parser.add_argument("--top_p", type=float, default=0.5, help="Top-p for sampling")

    args_cli = parser.parse_args()
    print(args_cli.model_path)

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="amber")) as demo:

        HEADER = ("""
        <div>
            <h1>VideoRefer X VideoLLaMA3 Demo</h1>
            <h5 style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
            <h5 style="margin: 0;">If this demo pleases you, please give us a star ⭐ on Github or 💖 on this space.</h5>
        </div>
        <div style="display: flex; justify-content: left; margin-top: 10px;">
            <a href="https://arxiv.org/pdf/2501.00599"><img src="https://img.shields.io/badge/Arxiv-2501.00599-ECA8A7" style="margin-right: 5px;"></a>
            <a href="https://github.com/DAMO-NLP-SG/VideoRefer"><img src='https://img.shields.io/badge/Github-VideoRefer-F7C97E' style="margin-right: 5px;"></a>
            <a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9DC3E6' style="margin-right: 5px;"></a>
        </div>
        """)

        with gr.Row():
            with gr.Column():
                gr.HTML(HEADER)

        image_tips = """
### 💡 Tips:

🧸 Upload an image, and you can use the drawing tool✍️ to highlight the areas you're interested in.

🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.

🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.

📌 Click the 'Clear Masks' button to clear the currently generated masks.

"""

        video_tips = """
### 💡 Tips:
⚠️ For video mode, we only support masking on the first frame in this demo.

🧸 Upload a video, and you can use the drawing tool✍️ to highlight the areas you're interested in on the first frame.

🔖 For single-object caption mode, simply select the area and click the 'Generate Caption' button to receive a caption for the object.

🔔 In QA mode, you can generate multiple masks by clicking the 'Generate Mask' button multiple times. Afterward, use the corresponding object id to ask questions.

📌 Click the 'Clear Masks' button to clear the currently generated masks.

"""

        with gr.TabItem("Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.ImageEditor(
                        label="Image",
                        type="pil",
                        sources=['upload'],
                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
                        eraser=True,
                        layers=False,
                        transforms=[],
                        height=300,
                    )
                    generate_mask_btn = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
                    mode = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
                    query = gr.Textbox(label="Question", value="What is the relationship between <region0> and <region1>?", interactive=True, visible=False)

                    submit_btn = gr.Button("Generate Caption", variant="primary")
                    submit_btn1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
                    gr.Examples([f"./demo/images/{i+1}.jpg" for i in range(8)], inputs=image_input, label="Examples")

                with gr.Column():
                    mask_output = gr.Gallery(label="Referred Masks", object_fit='scale-down', visible=False)
                    output_image = gr.Image(label="Image with Mask", visible=True, height=400)
                    description = gr.Textbox(label="Output", visible=True)

                    clear_masks_btn = gr.Button("Clear Masks", variant="secondary", visible=False)
                    gr.Markdown(image_tips)

        with gr.TabItem("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video")
                    # load_btn = gr.Button("🖼️ Load First Frame", variant="secondary")
                    first_frame = gr.ImageEditor(
                        label="Annotate First Frame",
                        type="pil",
                        sources=['upload'],
                        brush=gr.Brush(colors=["#ED7D31"], color_mode="fixed", default_size=10),
                        eraser=True,
                        layers=False,
                        transforms=[],
                        height=300,
                    )
                    generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=False, variant="primary")
                    gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")

                with gr.Column():
                    mode_video = gr.Radio(label="Mode", choices=["Caption", "QA"], value="Caption")
                    mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')

                    query_video = gr.Textbox(label="Question", value="What is the relationship between <object0> and <object1>?", interactive=True, visible=False)

                    submit_btn_video = gr.Button("Generate Caption", variant="primary")
                    submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=False)
                    description_video = gr.Textbox(label="Output", visible=True)

                    clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")

                    gr.Markdown(video_tips)


        def toggle_query_and_generate_button(mode):
            query_visible = mode == "QA"
            caption_visible = mode == "Caption"
            global mask_list
            global mask_raw_list
            mask_list = []
            mask_raw_list = []
            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), gr.update(visible=caption_visible), [], ""

        video_input.change(load_first_frame, inputs=video_input, outputs=first_frame)

        mode.change(toggle_query_and_generate_button, inputs=mode, outputs=[query, generate_mask_btn, clear_masks_btn, submit_btn1, mask_output, output_image, submit_btn, mask_output, description])

        def toggle_query_and_generate_button_video(mode):
            query_visible = mode == "QA"
            caption_visible = mode == "Caption"
            global mask_list_video
            global mask_raw_list_video
            mask_list_video = []
            mask_raw_list_video = []
            return gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=query_visible), gr.update(visible=caption_visible), []

        mode_video.change(toggle_query_and_generate_button_video, inputs=mode_video, outputs=[query_video, generate_mask_btn_video, submit_btn_video1, submit_btn_video, mask_output_video])

        submit_btn.click(
            fn=describe,
            inputs=[image_input, mode, query],
            outputs=[output_image, description, image_input],
            api_name="describe"
        )

        submit_btn1.click(
            fn=describe,
            inputs=[image_input, mode, query],
            outputs=[output_image, description, image_input],
            api_name="describe"
        )

        generate_mask_btn.click(
            fn=generate_masks,
            inputs=[image_input],
            outputs=[mask_output, image_input]
        )

        generate_mask_btn_video.click(
            fn=generate_masks_video,
            inputs=[first_frame],
            outputs=[mask_output_video, first_frame]
        )

        clear_masks_btn.click(
            fn=clear_masks,
            outputs=[mask_output]
        )

        clear_masks_btn_video.click(
            fn=clear_masks_video,
            outputs=[mask_output_video]
        )

        submit_btn_video.click(
            fn=describe_video,
            inputs=[video_input, mode_video, query_video, first_frame],
            outputs=[first_frame, description_video, mask_output_video],
            api_name="describe_video"
        )

        submit_btn_video1.click(
            fn=describe_video,
            inputs=[video_input, mode_video, query_video, first_frame],
            outputs=[first_frame, description_video, mask_output_video],
            api_name="describe_video"
        )


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

    disable_torch_init()

    model, processor, tokenizer = model_init(args_cli.model_path)

    demo.launch()
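For reference, the point-prompted SAM step that apply_sam() performs above can be exercised on its own. The snippet below is a minimal, self-contained sketch, not part of the committed app: it assumes the "facebook/sam-vit-huge" checkpoint is downloadable and that an image such as demo/images/1.jpg exists locally, and the (x, y) click coordinate is an arbitrary illustration.

# Minimal sketch of the SAM point-prompt flow used by apply_sam() above.
# Assumptions: "facebook/sam-vit-huge" is reachable and "demo/images/1.jpg"
# exists; the single (x, y) point below is an arbitrary example value.
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

image = Image.open("demo/images/1.jpg").convert("RGB")
input_points = [[[320, 240]]]  # one click inside the region of interest

inputs = sam_processor(image, input_points=input_points, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = sam_model(**inputs)

# Keep the highest-scoring of the proposed masks, as the demo does.
masks = sam_processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)[0][0]
best = outputs.iou_scores[0, 0].argmax()
mask_np = masks[best].numpy()
print("mask pixels:", int(mask_np.sum()))

The resulting mask_np plays the same role as the masks collected by generate_masks() before being handed to get_model_output().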
requirements.txt
ADDED
@@ -0,0 +1,48 @@
--extra-index-url https://download.pytorch.org/whl/cu124
--extra-index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://download.pytorch.org/whl/cu118
# basic dependencies
torch==2.4.0
torchvision==0.19.0
datasets==2.21.0
transformers==4.46.3
tokenizers==0.20.3
deepspeed==0.15.4
accelerate==1.0.1
peft==0.4.0
timm==1.0.3
numpy==1.24.4
# data processing
decord==0.6.0
imageio==2.34.0
imageio-ffmpeg==0.4.9
moviepy==1.0.3
scenedetect==0.6.3
opencv-python==4.6.0.66
pyarrow
pysubs2
ffmpeg-python
# misc
scikit-learn==1.2.2
huggingface_hub==0.23.4
sentencepiece==0.1.99
shortuuid
einops==0.6.1
einops-exts==0.0.4
bitsandbytes==0.43.3 # for cuda 124
pydantic>=2.0
markdown2[all]
gradio==5.34.0
gradio_client==1.10.3
httpx==0.24.1
requests
openai
uvicorn
fastapi
tensorboard
wandb
tabulate
Levenshtein
pycocotools==2.0.8
spaces
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
videollama3/.DS_Store
ADDED
Binary file (6.15 kB).
videollama3/__init__.py
ADDED
@@ -0,0 +1,239 @@
import os
import copy
import math
import warnings
import shutil
from functools import partial

import torch

from .model import load_pretrained_model
from .model.processor import Videollama3Processor
from .mm_utils import load_images, process_images, load_video, process_video, tokenizer_multimodal_token, get_model_name_from_path, KeywordsStoppingCriteria, resize_image_mask
from .constants import NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN, MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN
from videollama3.constants import REGION_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread

def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def model_init(model_path=None, **kwargs):
    model_path = "DAMO-NLP-SG/VideoLLaMA2-7B" if model_path is None else model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)

    if tokenizer.pad_token is None and tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    aspect_ratio = model.config.image_aspect_ratio if hasattr(model.config, "image_aspect_ratio") else "pad"
    image_size = model.config.image_size if hasattr(model.config, "image_size") else 384
    # NOTE: If num_frames is None, the frame sampling mode is "fps". If num_frames is not None, the frame sampling mode is "uniform".
    # num_frames = model.config.num_frames
    model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)
    processor = {
        'image': load_images,
        'video': load_video,
        'text': None
    }

    return model, processor, tokenizer


def get_model_output(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
    streaming = kwargs.pop('streaming', False)
    if streaming:
        return mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=True, **kwargs)
    else:
        output = mm_infer(images_or_videos, instruct, model, tokenizer, modal, streaming=False, **kwargs)
        return next(output)


def mm_infer(images_or_videos, instruct, model, tokenizer, modal='video', **kwargs):
    """inference api of VideoLLaMA2 for video understanding.

    Args:
        model: VideoLLaMA2 model.
        images_or_videos (torch.Tensor): image tensor (1, C, H, W) / video tensor (T, C, H, W).
        instruct (str): text instruction for understanding video.
        tokenizer: tokenizer.
        do_sample (bool): whether to sample.
        modal (str): inference modality.
    Returns:
        str: response of the model.
    """
    mask_ids = kwargs.pop('mask_ids', None)
    masks = kwargs.pop('masks', None)
    streaming = kwargs.pop('streaming', False)
    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        additional_frames = images.copy()
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps, additional_frames = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    vlprocessor = Videollama3Processor(model.get_vision_encoder().image_processor, tokenizer)
    vlprocessor.tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)

    model.config.image_token_index = vlprocessor.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)

    if masks is not None:
        additional_frames, masks, mask_nums = resize_image_mask(additional_frames, masks, mask_ids)

        for idx in range(len(mask_nums)):
            instruct = instruct.replace('<region>', "[" + REGION_TOKEN * mask_nums[idx] + "]", 1)

        additional_images_dict = vlprocessor._process_image(additional_frames, image_downsampling=1)
        additional_images = additional_images_dict['images']
        # import pdb
        # pdb.set_trace()

        # flatten_patches1 = additional_images[0].reshape(26, 46, 3, -1)
        # from matplotlib import pyplot as plt
        # plt.imshow(flatten_patches1[:,:,:,0])
        # plt.savefig('16.png')

        additional_images_thws = additional_images_dict['grid_thws']
        additional_images = (additional_images, additional_images_thws)

    else:
        additional_images = None

    # 1. text preprocess (tag process & generate prompt).
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(not modal_token in message["content"] for message in messages):
        warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
        messages[0]["content"] = modal_token + messages[0]["content"]

    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                converted_messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    # 2. vision preprocess (load & transform image or video).
    if model.config.model_type in ['videollama3_mistral', 'videollama3_mixtral']:
        system_message = [
            {'role': 'system', 'content': (
                """<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature."""
                """\n"""
                """If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>""")
            }
        ]
    else:
        system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)
    # TODO: attention mask?
    messages = system_message + messages
    data_dict = vlprocessor(
        images=images,
        text=messages,
        image_downsampling=image_downsampling,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
    grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]

    # 3. generate response according to visual signals and prompts.
    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"])
    stop_str = tokenizer.eos_token

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 0.0)
    top_p = kwargs.get('top_p', 0.9)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)
    if not streaming:
        with torch.inference_mode():
            output_ids = model.generate(
                # input_ids,
                # attention_mask=attention_masks,
                # images=images,
                data_dict["input_ids"].cuda(),
                attention_mask=data_dict["attention_mask"].cuda(),
                images=[(modal, images, grid_thws)],
                do_sample=do_sample,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                pad_token_id=tokenizer.eos_token_id,
                additional_images=[additional_images],
                masks=[masks],
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

        yield outputs

    else:
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=data_dict["input_ids"].cuda(),
            attention_mask=data_dict["attention_mask"].cuda(),
            images=[(modal, images, grid_thws)],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            additional_images=[additional_images],
            masks=[masks],
            streamer=streamer
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            if stop_str in generated_text:
                generated_text = generated_text[:generated_text.find(stop_str)]
                break
            yield new_text

        thread.join()
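As a usage note, get_model_output() wraps mm_infer() and either returns the fully decoded answer (default) or a generator of text chunks when streaming=True. The sketch below shows both call patterns; it assumes model and tokenizer come from model_init(...) and that image (an H x W x 3 uint8 array), masks (a uint8 tensor of shape [N, H, W]) and mask_ids have been prepared as in infer.py / app.py.

# Sketch only: `model`, `tokenizer`, `image`, `masks` and `mask_ids` are
# assumed to be prepared exactly as in infer.py / app.py.
question = '<image>\nPlease describe the <region> in the image in detail.'

# Non-streaming: returns the decoded answer as a single string.
answer = get_model_output(
    [image], question,
    model=model, tokenizer=tokenizer,
    masks=masks, mask_ids=mask_ids,
    modal='image', image_downsampling=1,
)
print(answer)

# Streaming: returns a generator that yields text chunks as they are produced.
for chunk in get_model_output(
    [image], question,
    model=model, tokenizer=tokenizer,
    masks=masks, mask_ids=mask_ids,
    modal='image', image_downsampling=1,
    streaming=True,
):
    print(chunk, end="", flush=True)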
videollama3/constants.py
ADDED
@@ -0,0 +1,46 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100

# Image arguments
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"

# Video arguments
VIDEO_TOKEN_INDEX = -201
DEFAULT_VIDEO_TOKEN = "<video>"
NUM_FRAMES = 128
MAX_FRAMES = 768
NUM_FRAMES_PER_SECOND = 1

# Region arguments
REGION_TOKEN = "<REGION>"

# Audio arguments
AUDIO_TOKEN_INDEX = -202
DEFAULT_AUDIO_TOKEN = "<audio>"

# Stream arguments
STREAM_START_TOKEN = "<|stream_start|>"
STREAM_END_TOKEN = "<|stream_end|>"
STREAM_IMAGE_TOKEN = "<stream_image>"
STREAM_FPS = 2
STREAM_IMAGE_SIZE = 224
STREAM_DOWNSAMPLING = 4
STREAM_MAX_FRAMES = 400

MODAL_INDEX_MAP = {
    "<image>": -200,
    "<video>": -201,
    "<audio>": -202,
}

subimage_token_num = 196
videollama3/infer.py
ADDED
@@ -0,0 +1,82 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import sys
sys.path.append('./')
from videollama3 import disable_torch_init, model_init, mm_infer, get_model_output
from videollama3.mm_utils import load_video

import numpy as np
from PIL import Image

def infer_image(model, tokenizer):
    image_path = 'demo/images/1.jpg'
    image = Image.open(image_path)
    image_data = np.array(image)

    question = '<image>\nPlease describe the <region> in the image in detail.'

    mask = np.load('demo/masks/demo0.npy')
    masks = []
    masks.append(mask)
    masks = np.array(masks)
    masks = torch.from_numpy(masks).to(torch.uint8)

    mask_ids = [0] * len(masks)

    output = get_model_output(
        [image_data],
        question,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='image',
        image_downsampling=1,
    )
    print(output)

def infer_video(model, tokenizer):
    video_path = 'demo/videos/1.mp4'
    question = '<video>\nPlease describe the <region> in the video in detail.'

    frame_idx = 0  # mask from the first frame
    video_tensor = load_video(video_path, fps=1, max_frames=768, frame_ids=[frame_idx])

    mask = np.load('demo/masks/demo1.npy')
    masks = []
    masks.append(mask)
    masks = np.array(masks)
    masks = torch.from_numpy(masks).to(torch.uint8)

    mask_ids = [0] * len(masks)

    output = get_model_output(
        video_tensor,
        question,
        model=model,
        tokenizer=tokenizer,
        masks=masks,
        mask_ids=mask_ids,
        modal='video',
    )
    print(output)

def main():
    disable_torch_init()

    # fill in the model path here
    model_path = '/mnt/workspace/workgroup/yuanyq/code/videollama3/ProjectX_region/work_dirs/VideoRefer-VideoLLaMA3-7B'
    model, processor, tokenizer = model_init(model_path)

    # image
    infer_image(model, tokenizer)

    # video
    infer_video(model, tokenizer)


if __name__ == '__main__':
    main()
videollama3/mm_utils.py
ADDED
@@ -0,0 +1,704 @@
1 |
+
import ast
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import math
|
5 |
+
import base64
|
6 |
+
import traceback
|
7 |
+
from io import BytesIO
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torchvision.transforms.functional as VF
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import numpy as np
|
14 |
+
from transformers import StoppingCriteria
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import imageio
|
18 |
+
import ffmpeg
|
19 |
+
from PIL import Image
|
20 |
+
from decord import VideoReader, cpu
|
21 |
+
|
22 |
+
from .constants import NUM_FRAMES, MAX_FRAMES, NUM_FRAMES_PER_SECOND, MODAL_INDEX_MAP, DEFAULT_IMAGE_TOKEN
|
23 |
+
from pycocotools import mask as maskUtils
|
24 |
+
|
25 |
+
def resize_image_mask(images, masks, mask_ids, patch_size=14):
|
26 |
+
resize_images = []
|
27 |
+
resize_masks = []
|
28 |
+
mask_nums = []
|
29 |
+
for i, mask in enumerate(masks):
|
30 |
+
image = images[mask_ids[i]]
|
31 |
+
h, w = image.shape[:2]
|
32 |
+
if mask.sum()==0:
|
33 |
+
print('mask is none...')
|
34 |
+
mask = torch.ones((h, w))
|
35 |
+
rows, cols = np.where(mask == 1)
|
36 |
+
|
37 |
+
min_row, max_row = rows.min(), rows.max()
|
38 |
+
min_col, max_col = cols.min(), cols.max()
|
39 |
+
|
40 |
+
bbox = (max(0,min_row-patch_size*2), max(0,min_col-patch_size*2), min(h-1, max_row+patch_size*2), min(w-1, max_col+patch_size*2))
|
41 |
+
mask_h = bbox[2] - bbox[0]
|
42 |
+
mask_w = bbox[3] - bbox[1]
|
43 |
+
cropping_img = image[bbox[0]: bbox[2], bbox[1]: bbox[3], :]
|
44 |
+
cropping_mask = mask[bbox[0]: bbox[2], bbox[1]: bbox[3]]
|
45 |
+
|
46 |
+
scale_rate = math.ceil(math.sqrt(1960/mask.sum()))
|
47 |
+
if scale_rate==1:
|
48 |
+
if (mask.sum()/196)>100:
|
49 |
+
scale_rate = math.sqrt((mask.sum()/196)/100)
|
50 |
+
scale_rate = 1/scale_rate
|
51 |
+
resize_h = math.ceil((mask_h*scale_rate)/patch_size) * patch_size
|
52 |
+
resize_w = math.ceil((mask_w*scale_rate)/patch_size) * patch_size
|
53 |
+
|
54 |
+
resize_img = cv2.resize(cropping_img, (resize_w, resize_h))
|
55 |
+
resize_mask = F.interpolate(cropping_mask[None, None], size=(resize_h//patch_size, resize_w//patch_size), mode='bilinear', align_corners=False)[0,0]
|
56 |
+
mask_nums.append(min(10, int(resize_mask.sum())))
|
57 |
+
|
58 |
+
resize_images.append(resize_img)
|
59 |
+
resize_masks.append(resize_mask)
|
60 |
+
|
61 |
+
return resize_images, resize_masks, mask_nums
|
62 |
+
|
63 |
+
def reshape_images_to_raw_grid(mm_features_raw, grid_thws):
|
64 |
+
start_idx=0
|
65 |
+
reshaped_features = []
|
66 |
+
for thw_group in grid_thws:
|
67 |
+
for tensor_thw in thw_group:
|
68 |
+
_, H, W = tensor_thw.squeeze().tolist()
|
69 |
+
num_elements = H * W
|
70 |
+
|
71 |
+
split_tensor = mm_features_raw[start_idx:start_idx + num_elements].view(H, W, -1)
|
72 |
+
reshaped_features.append(split_tensor)
|
73 |
+
|
74 |
+
start_idx += num_elements
|
75 |
+
assert len(mm_features_raw)==start_idx
|
76 |
+
return reshaped_features
|
77 |
+
|
78 |
+
def annToMask(mask_ann, h=None, w=None):
|
79 |
+
if isinstance(mask_ann, list):
|
80 |
+
rles = maskUtils.frPyObjects(mask_ann, h, w)
|
81 |
+
rle = maskUtils.merge(rles)
|
82 |
+
elif isinstance(mask_ann['counts'], list):
|
83 |
+
# uncompressed RLE
|
84 |
+
rle = maskUtils.frPyObjects(mask_ann, h, w)
|
85 |
+
else:
|
86 |
+
# rle
|
87 |
+
rle = mask_ann
|
88 |
+
mask = maskUtils.decode(rle)
|
89 |
+
return mask
|
90 |
+
|
91 |
+
def chunk_list(input_list, chunk_size):
|
92 |
+
return [input_list[i:i + chunk_size] for i in range(0, len(input_list), chunk_size)]
|
93 |
+
|
94 |
+
|
95 |
+
def load_image_from_base64(image):
|
96 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
97 |
+
|
98 |
+
|
99 |
+
def expand2square(pil_img, background_color):
|
100 |
+
width, height = pil_img.size
|
101 |
+
if width == height:
|
102 |
+
return pil_img
|
103 |
+
elif width > height:
|
104 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
105 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
106 |
+
return result
|
107 |
+
else:
|
108 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
109 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
110 |
+
return result
|
111 |
+
|
112 |
+
|
113 |
+
def grid_divide(image, cell_size):
|
114 |
+
"""
|
115 |
+
Divides an image into grid of a specified size.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
image (PIL.Image.Image): The input image.
|
119 |
+
cell_size (int): The size of each cell.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
123 |
+
"""
|
124 |
+
grid = []
|
125 |
+
width, height = image.size
|
126 |
+
for i in range(0, height, cell_size):
|
127 |
+
row = []
|
128 |
+
for j in range(0, width, cell_size):
|
129 |
+
box = (j, i, j + cell_size, i + cell_size)
|
130 |
+
row.append(image.crop(box))
|
131 |
+
grid.append(row)
|
132 |
+
|
133 |
+
return grid
|
134 |
+
|
135 |
+
|
136 |
+
def load_images(image_path):
|
137 |
+
if isinstance(image_path, str) and os.path.isfile(image_path):
|
138 |
+
images = [cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)]
|
139 |
+
# images = [Image.open(image_path).convert('RGB')]
|
140 |
+
elif isinstance(image_path, str) and os.path.isdir(image_path):
|
141 |
+
images = [cv2.cvtColor(cv2.imread(os.path.join(image_path, f)), cv2.COLOR_BGR2RGB) for f in sorted(os.listdir(image_path))]
|
142 |
+
# images = [Image.open(os.path.join(image_path, f)).convert('RGB') for f in sorted(os.listdir(image_path))]
|
143 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], str):
|
144 |
+
images = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in image_path]
|
145 |
+
# images = [Image.open(f).convert('RGB') for f in image_path]
|
146 |
+
elif isinstance(image_path, list) and isinstance(image_path[0], Image.Image):
|
147 |
+
images = image_path
|
148 |
+
elif isinstance(image_path, Image.Image):
|
149 |
+
images = [image_path]
|
150 |
+
else:
|
151 |
+
print('image_path: ', image_path)
|
152 |
+
raise ValueError(f"Unsupported image path type: {image_path}")
|
153 |
+
|
154 |
+
return images
|
155 |
+
|
156 |
+
|
157 |
+
def process_pad_image(image, padding_value=(0, 0, 0)):
|
158 |
+
image = expand2square(image, padding_value)
|
159 |
+
|
160 |
+
return [image]
|
161 |
+
|
162 |
+
|
163 |
+
def find_closest_aspect_ratio(src_ratio, tgt_ratios, ori_size, tgt_size):
|
164 |
+
best_ratio_diff = float('inf')
|
165 |
+
best_ratio = (1, 1)
|
166 |
+
area = ori_size[0] * ori_size[1]
|
167 |
+
for ratio in tgt_ratios:
|
168 |
+
tgt_ratio = ratio[0] / ratio[1]
|
169 |
+
ratio_diff = abs(src_ratio - tgt_ratio)
|
170 |
+
if ratio_diff < best_ratio_diff:
|
171 |
+
best_ratio_diff = ratio_diff
|
172 |
+
best_ratio = ratio
|
173 |
+
elif ratio_diff == best_ratio_diff:
|
174 |
+
if area > 0.5 * tgt_size[0] * tgt_size[1] * ratio[0] * ratio[1]:
|
175 |
+
best_ratio = ratio
|
176 |
+
|
177 |
+
return best_ratio
|
178 |
+
|
179 |
+
|
180 |
+
def process_dynamic_image(image, image_size=384, use_thumbnail=True):
|
181 |
+
# Grid Params:
|
182 |
+
min_num = 1
|
183 |
+
max_num = 12
|
184 |
+
|
185 |
+
if isinstance(image_size, int):
|
186 |
+
image_size = (image_size, image_size)
|
187 |
+
|
188 |
+
ori_size = image.size
|
189 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
190 |
+
|
191 |
+
# calculate the existing image aspect ratio
|
192 |
+
tgt_ratios = []
|
193 |
+
for n in range(min_num, max_num + 1):
|
194 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
195 |
+
tgt_ratios = set(tgt_ratios)
|
196 |
+
tgt_ratios = sorted(tgt_ratios, key=lambda x: x[0] * x[1])
|
197 |
+
|
198 |
+
# find the closest aspect ratio to the target
|
199 |
+
tgt_ratio = find_closest_aspect_ratio(aspect_ratio, tgt_ratios, ori_size, image_size)
|
200 |
+
|
201 |
+
# resize the image to the target size
|
202 |
+
tgt_width = image_size[0] * tgt_ratio[0]
|
203 |
+
tgt_height = image_size[1] * tgt_ratio[1]
|
204 |
+
resized_img = image.resize((tgt_width, tgt_height))
|
205 |
+
|
206 |
+
# NOTE: internvl2 style split the image into one column grids
|
207 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
208 |
+
# grid_images = []
|
209 |
+
# for i in range(num_grids):
|
210 |
+
# box = (
|
211 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
212 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
213 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
214 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
215 |
+
# )
|
216 |
+
# # crop out the grid image
|
217 |
+
# grid_images.append(resized_img.crop(box))
|
218 |
+
# assert len(grid_images) == num_grids
|
219 |
+
# grid_images = [grid_images]
|
220 |
+
|
221 |
+
# NOTE: eager implementation
|
222 |
+
# num_grids = tgt_ratio[0] * tgt_ratio[1]
|
223 |
+
# sub_grid_images = []
|
224 |
+
# tmp_grid_images = []
|
225 |
+
# for i in range(num_grids):
|
226 |
+
# box = (
|
227 |
+
# (i % tgt_ratio[0]) * image_size[0],
|
228 |
+
# (i // tgt_ratio[0]) * image_size[1],
|
229 |
+
# (i % tgt_ratio[0] + 1) * image_size[0],
|
230 |
+
# (i // tgt_ratio[0] + 1) * image_size[1],
|
231 |
+
# )
|
232 |
+
# tmp_grid_images.append(resized_img.crop(box))
|
233 |
+
|
234 |
+
# if (i + 1) % tgt_ratio[0] == 0:
|
235 |
+
# sub_grid_images.append(tmp_grid_images)
|
236 |
+
# tmp_grid_images = []
|
237 |
+
|
238 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
239 |
+
|
240 |
+
if use_thumbnail:
|
241 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
242 |
+
image_grid = [[thumbnail_img]] + image_grid
|
243 |
+
|
244 |
+
return image_grid
|
245 |
+
|
246 |
+
|
247 |
+
def process_highres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
248 |
+
# Grid Params:
|
249 |
+
grid_width = [1, 2, 3]
|
250 |
+
grid_width_real = [x * image_size for x in grid_width]
|
251 |
+
|
252 |
+
longest_side = max(image.size)
|
253 |
+
fit_grid_width_real = [x for x in grid_width_real if x >= longest_side]
|
254 |
+
if len(fit_grid_width_real) == 0:
|
255 |
+
select_size = max(grid_width_real)
|
256 |
+
else:
|
257 |
+
select_size = min(fit_grid_width_real)
|
258 |
+
|
259 |
+
image_padded = expand2square(image, padding_value)
|
260 |
+
image_padded = image_padded.resize((select_size, select_size))
|
261 |
+
image_grid = grid_divide(image_padded, image_size)
|
262 |
+
|
263 |
+
if use_thumbnail:
|
264 |
+
thumbnail_img = image.resize((image_size, image_size))
|
265 |
+
image_grid = [[thumbnail_img]] + image_grid
|
266 |
+
|
267 |
+
return image_grid
|
268 |
+
|
269 |
+
|
270 |
+
def select_best_resolution(original_size, possible_resolutions):
|
271 |
+
"""
|
272 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
276 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
tuple: The best fit resolution in the format (width, height).
|
280 |
+
"""
|
281 |
+
original_width, original_height = original_size
|
282 |
+
best_fit = None
|
283 |
+
max_effective_resolution = 0
|
284 |
+
min_wasted_resolution = float('inf')
|
285 |
+
|
286 |
+
for width, height in possible_resolutions:
|
287 |
+
scale = min(width / original_width, height / original_height)
|
288 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
289 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
290 |
+
wasted_resolution = (width * height) - effective_resolution
|
291 |
+
|
292 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
293 |
+
max_effective_resolution = effective_resolution
|
294 |
+
min_wasted_resolution = wasted_resolution
|
295 |
+
best_fit = (width, height)
|
296 |
+
|
297 |
+
return best_fit
|
298 |
+
|
299 |
+
|
300 |
+
def process_anyres_image(image, image_size=384, use_thumbnail=True, padding_value=(0, 0, 0)):
|
301 |
+
"""
|
302 |
+
Process an image with variable resolutions.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
image (PIL.Image.Image): The input image to be processed.
|
306 |
+
processor: The image processor object.
|
307 |
+
|
308 |
+
Returns:
|
309 |
+
torch.Tensor: A tensor containing the processed image patches.
|
310 |
+
"""
|
311 |
+
# Grid Params:
|
312 |
+
possible_grids = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3)]
|
313 |
+
possible_resolutions = [(x * image_size, y * image_size) for x, y in possible_grids]
|
314 |
+
|
315 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
316 |
+
|
317 |
+
# resize and padding image
|
318 |
+
nw, nh = best_resolution
|
319 |
+
ow, oh = image.size
|
320 |
+
|
321 |
+
scale_factor = min(nw / ow, nh / oh)
|
322 |
+
new_size = (int(ow * scale_factor), int(oh * scale_factor))
|
323 |
+
|
324 |
+
image_padded = Image.new("RGB", (nw, nh), padding_value)
|
325 |
+
image_padded.paste(image.resize(new_size), ((nw - new_size[0]) // 2, (nh - new_size[1]) // 2))
|
326 |
+
|
327 |
+
image_grid = grid_divide(image_padded, image_size)
|
328 |
+
|
329 |
+
if use_thumbnail:
|
330 |
+
thumbnail_img = image.resize((image_size, image_size))
|
331 |
+
image_grid = [[thumbnail_img]] + image_grid
|
332 |
+
|
333 |
+
return image_grid
|
334 |
+
|
335 |
+
|
336 |
+
def process_adares_image(image, image_size=384, use_thumbnail=True):
|
337 |
+
# Grid Params:
|
338 |
+
min_num = 1
|
339 |
+
max_num = 12
|
340 |
+
|
341 |
+
if isinstance(image_size, int):
|
342 |
+
image_size = (image_size, image_size)
|
343 |
+
|
344 |
+
ori_size = image.size
|
345 |
+
aspect_ratio = ori_size[0] / ori_size[1]
|
346 |
+
|
347 |
+
# calculate the existing image aspect ratio
|
348 |
+
tgt_ratios = []
|
349 |
+
for n in range(min_num, max_num + 1):
|
350 |
+
tgt_ratios.extend([(i, j) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num])
|
351 |
+
tgt_ratios = set(tgt_ratios)
|
352 |
+
possible_resolutions = [(x * image_size[0], y * image_size[1]) for x, y in tgt_ratios]
|
353 |
+
|
354 |
+
# find the most possible resolution
|
355 |
+
best_resolution = select_best_resolution(ori_size, possible_resolutions)
|
356 |
+
|
357 |
+
# resize the image to the target size
|
358 |
+
resized_img = image.resize((best_resolution[0], best_resolution[1]))
|
359 |
+
|
360 |
+
image_grid = grid_divide(resized_img, image_size[0])
|
361 |
+
|
362 |
+
if use_thumbnail:
|
363 |
+
thumbnail_img = image.resize((image_size[0], image_size[1]))
|
364 |
+
image_grid = [[thumbnail_img]] + image_grid
|
365 |
+
|
366 |
+
return image_grid
|
367 |
+
|
368 |
+
|
369 |
+
def process_images(image_path, processor, aspect_ratio='pad', image_size=384, use_thumbnail=True):
|
370 |
+
images = load_images(image_path)
|
371 |
+
|
372 |
+
padding_value = tuple(int(x*255) for x in processor.image_mean)
|
373 |
+
|
374 |
+
image_grids = []
|
375 |
+
for image in images:
|
376 |
+
if aspect_ratio == 'pad':
|
377 |
+
image_grid = process_pad_image(image, padding_value=padding_value)
|
378 |
+
elif aspect_ratio == 'dynamic':
|
379 |
+
image_grid = process_dynamic_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
380 |
+
elif aspect_ratio == 'highres':
|
381 |
+
image_grid = process_highres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
382 |
+
elif aspect_ratio == 'anyres':
|
383 |
+
image_grid = process_anyres_image(image, image_size=image_size, use_thumbnail=use_thumbnail, padding_value=padding_value)
|
384 |
+
elif aspect_ratio == 'adares':
|
385 |
+
image_grid = process_adares_image(image, image_size=image_size, use_thumbnail=use_thumbnail)
|
386 |
+
else:
|
387 |
+
image_grid = [image]
|
388 |
+
|
389 |
+
image_grid = [processor.preprocess(image_row, return_tensors='pt', num_images=len(images)) for image_row in image_grid]
|
390 |
+
image_grids.append(image_grid)
|
391 |
+
|
392 |
+
return image_grids
|
393 |
+
|
394 |
+
|
395 |
+
def frame_sample(duration, mode='uniform', num_frames=None, vid_fps=None, fps=None):
|
396 |
+
if mode == 'uniform':
|
397 |
+
assert num_frames is not None, "Number of frames must be provided for uniform sampling."
|
398 |
+
if duration <= num_frames:
|
399 |
+
return np.arange(duration).astype(int)
|
400 |
+
# NOTE: v1 version
|
401 |
+
# Calculate the size of each segment from which a frame will be extracted
|
402 |
+
# if duration <= num_frames:
|
403 |
+
# return np.arange(duration).astype(int)
|
404 |
+
# seg_size = float(duration - 1) / num_frames
|
405 |
+
|
406 |
+
# frame_ids = []
|
407 |
+
# for i in range(num_frames):
|
408 |
+
# # Calculate the start and end indices of each segment
|
409 |
+
# start = seg_size * i
|
410 |
+
# end = seg_size * (i + 1)
|
411 |
+
# # Append the middle index of the segment to the list
|
412 |
+
# frame_ids.append((start + end) / 2)
|
413 |
+
|
414 |
+
# return np.round(np.array(frame_ids) + 1e-6).astype(int)
|
415 |
+
# NOTE: v0 version
|
416 |
+
return np.linspace(0, duration-1, num_frames, dtype=int)
|
417 |
+
elif mode == 'fps':
|
418 |
+
assert vid_fps is not None, "FPS must be provided for FPS sampling."
|
419 |
+
fps = fps if fps is not None else NUM_FRAMES_PER_SECOND
|
420 |
+
segment_len = min(vid_fps // fps, duration)
|
421 |
+
return np.arange(segment_len // 2, duration, segment_len, dtype=int)
|
422 |
+
else:
|
423 |
+
raise ValueError(f'Unsupported frame sampling mode: {mode}')
|
424 |
+
|
425 |
+
|
426 |
+
def load_video_from_ids(video_path, s=None, e=None, fps=None, max_frames=None, temporal_factor=1, frame_ids=None):
|
427 |
+
if s is not None and e is not None:
|
428 |
+
s = s if s >= 0. else 0.
|
429 |
+
e = e if e >= 0. else 0.
|
430 |
+
if s > e:
|
431 |
+
s, e = e, s
|
432 |
+
elif s == e:
|
433 |
+
e = s + 1
|
434 |
+
|
435 |
+
# 1. Loading Video
|
436 |
+
if os.path.isdir(video_path):
|
437 |
+
frame_files = sorted(os.listdir(video_path))
|
438 |
+
|
439 |
+
vid_fps = 3
|
440 |
+
num_frames_of_video = len(frame_files)
|
441 |
+
elif video_path.endswith('.gif'):
|
442 |
+
gif_reader = imageio.get_reader(video_path)
|
443 |
+
|
444 |
+
vid_fps = 25
|
445 |
+
num_frames_of_video = len(gif_reader)
|
446 |
+
else:
|
447 |
+
vreader = VideoReader(video_path, ctx=cpu(0), num_threads=2)
|
448 |
+
# vreader = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
449 |
+
|
450 |
+
vid_fps = vreader.get_avg_fps()
|
451 |
+
num_frames_of_video = len(vreader)
|
452 |
+
|
453 |
+
# 2. Determine frame range & Calculate frame indices
|
454 |
+
f_start = 0 if s is None else max(int(s * vid_fps) - 1, 0)
|
455 |
+
f_end = num_frames_of_video - 1 if e is None else min(int(e * vid_fps) - 1, num_frames_of_video - 1)
|
456 |
+
frame_indices = list(range(f_start, f_end + 1))
|
457 |
+
|
458 |
+
duration = len(frame_indices)
|
459 |
+
# 3. Sampling frame indices
|
460 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
461 |
+
if fps is not None and duration / vid_fps < max_frames:
|
462 |
+
try:
|
463 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='fps', vid_fps=vid_fps, fps=fps)]
|
464 |
+
except Exception as e:
|
465 |
+
print('fps-based frame sampling failed, falling back to uniform sampling:', e)
|
466 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
|
467 |
+
|
468 |
+
else:
|
469 |
+
sampled_frame_indices = [frame_indices[i] for i in frame_sample(duration, mode='uniform', num_frames=max_frames)]
|
470 |
+
|
471 |
+
# 4. Acquire frame data
|
472 |
+
if os.path.isdir(video_path):
|
473 |
+
frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in sampled_frame_indices]
|
474 |
+
elif video_path.endswith('.gif'):
|
475 |
+
frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in sampled_frame_indices]
|
476 |
+
else:
|
477 |
+
frames = vreader.get_batch(sampled_frame_indices).asnumpy()
|
478 |
+
|
479 |
+
# frames = frames.transpose(0, 3, 1, 2)
|
480 |
+
timestamps = [x / vid_fps for x in sampled_frame_indices]
|
481 |
+
|
482 |
+
if temporal_factor > 1:
|
483 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
484 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
485 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
486 |
+
|
487 |
+
# NOTE: pad the video with black frames
|
488 |
+
# while num_frames is not None and len(video_data) < num_frames:
|
489 |
+
# video_data.append(Image.fromarray(np.zeros((*video_data[-1].size, 3), dtype=np.uint8)))
|
490 |
+
|
491 |
+
additional_frames = []
|
492 |
+
if frame_ids is not None:
|
493 |
+
if os.path.isdir(video_path):
|
494 |
+
additional_frames = [cv2.cvtColor(cv2.imread(os.path.join(video_path, frame_files[frame_idx])), cv2.COLOR_BGR2RGB) for frame_idx in frame_ids]
|
495 |
+
elif video_path.endswith('.gif'):
|
496 |
+
additional_frames = [cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB) for idx, frame in enumerate(gif_reader) if idx in frame_ids]
|
497 |
+
else:
|
498 |
+
additional_frames = vreader.get_batch(frame_ids).asnumpy()
|
499 |
+
|
500 |
+
return frames, timestamps, additional_frames
|
501 |
+
|
502 |
+
|
503 |
+
def load_video(
|
504 |
+
video_path: str,
|
505 |
+
start_time: Optional[float] = None,
|
506 |
+
end_time: Optional[float] = None,
|
507 |
+
fps: Optional[float] = None,
|
508 |
+
max_frames: Optional[float] = None,
|
509 |
+
size: Optional[int] = None,
|
510 |
+
size_divisible: int = 1,
|
511 |
+
precise_time: bool = False,
|
512 |
+
verbose: bool = False,
|
513 |
+
temporal_factor: int = 1,
|
514 |
+
frame_ids = None
|
515 |
+
):
|
516 |
+
"""
|
517 |
+
Load and process a video file and return the frames and the timestamps of each frame.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
video_path (str): Path to the video file.
|
521 |
+
start_time (float, optional): Start time in seconds. Defaults to None.
|
522 |
+
end_time (float, optional): End time in seconds. Defaults to None.
|
523 |
+
fps (float, optional): Frames per second. Defaults to None.
|
524 |
+
max_frames (float, optional): Maximum number of frames to sample. Defaults to None.
|
525 |
+
size (int, optional): Size of the shortest side. Defaults to None.
|
526 |
+
size_divisible (int, optional): Size divisible by this number. Defaults to 1.
|
527 |
+
precise_time (bool, optional): Whether to use precise time. Defaults to False.
|
528 |
+
verbose (bool, optional): Print ffmpeg output. Defaults to False.
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
frames (List[PIL.Image]): List of frames.
|
532 |
+
timestamps (List[float]): List of timestamps.
additional_frames (List): Frames fetched for the explicit indices in `frame_ids`; empty when `frame_ids` is None.
|
533 |
+
"""
|
534 |
+
if start_time is not None and end_time is not None and end_time - start_time < 1:
|
535 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
536 |
+
if os.path.isdir(video_path):
|
537 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
538 |
+
if video_path.endswith('.gif'):
|
539 |
+
return load_video_from_ids(video_path, start_time, end_time, fps=fps, max_frames=max_frames, frame_ids=frame_ids)
|
540 |
+
probe = ffmpeg.probe(video_path)
|
541 |
+
duration = float(probe['format']['duration'])
|
542 |
+
video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
|
543 |
+
w, h = int(video_stream['width']), int(video_stream['height'])
|
544 |
+
|
545 |
+
kwargs, input_kwargs, output_kwargs = {}, {}, {}
|
546 |
+
do_trim = start_time is not None or end_time is not None
|
547 |
+
if start_time is not None:
|
548 |
+
new_start_time = max(float(video_stream['start_time']), start_time)
|
549 |
+
duration -= new_start_time - start_time
|
550 |
+
start_time = new_start_time
|
551 |
+
else:
|
552 |
+
start_time = float(video_stream['start_time'])
|
553 |
+
if end_time is not None:
|
554 |
+
duration = min(duration, end_time - start_time)
|
555 |
+
else:
|
556 |
+
duration = duration
|
557 |
+
if do_trim:
|
558 |
+
kwargs = {'ss': start_time, 't': duration}
|
559 |
+
if precise_time:
|
560 |
+
output_kwargs.update(kwargs)
|
561 |
+
else:
|
562 |
+
input_kwargs.update(kwargs)
|
563 |
+
|
564 |
+
if size is not None:
|
565 |
+
scale_factor = size / min(w, h)
|
566 |
+
new_w, new_h = round(w * scale_factor), round(h * scale_factor)
|
567 |
+
else:
|
568 |
+
new_w, new_h = w, h
|
569 |
+
new_w = new_w // size_divisible * size_divisible
|
570 |
+
new_h = new_h // size_divisible * size_divisible
|
571 |
+
|
572 |
+
# NOTE: It may result in unexpected number of frames in ffmpeg
|
573 |
+
# if calculate the fps directly according to max_frames
|
574 |
+
# NOTE: the below lines may hurt the performance
|
575 |
+
# if max_frames is not None and (fps is None or duration * fps > 2 * max_frames):
|
576 |
+
# fps = max_frames / duration * 2
|
577 |
+
|
578 |
+
stream = ffmpeg.input(video_path, **input_kwargs)
|
579 |
+
if fps is not None:
|
580 |
+
stream = ffmpeg.filter(stream, "fps", fps=fps, round="down")
|
581 |
+
if new_w != w or new_h != h:
|
582 |
+
stream = ffmpeg.filter(stream, 'scale', new_w, new_h)
|
583 |
+
stream = ffmpeg.output(stream, "pipe:", format="rawvideo", pix_fmt="rgb24", **output_kwargs)
|
584 |
+
out, _ = ffmpeg.run(stream, capture_stdout=True, quiet=not verbose)
|
585 |
+
|
586 |
+
frames = np.frombuffer(out, np.uint8).reshape([-1, new_h, new_w, 3]).transpose([0, 3, 1, 2])
|
587 |
+
|
588 |
+
if fps is not None:
|
589 |
+
timestamps = np.arange(start_time, start_time + duration + 1 / fps, 1 / fps)[:len(frames)]
|
590 |
+
else:
|
591 |
+
timestamps = np.linspace(start_time, start_time + duration, len(frames))
|
592 |
+
|
593 |
+
max_frames = max_frames if max_frames is not None else MAX_FRAMES
|
594 |
+
if max_frames is not None and len(frames) > max_frames:
|
595 |
+
indices = np.linspace(0, len(frames) - 1, max_frames, dtype=int)
|
596 |
+
frames = frames[indices]
|
597 |
+
timestamps = [timestamps[i] for i in indices]
|
598 |
+
|
599 |
+
if temporal_factor > 1:
|
600 |
+
pad_length = temporal_factor - len(frames) % temporal_factor
|
601 |
+
frames = np.concatenate([frames, frames[-1:].repeat(pad_length, axis=0)])
|
602 |
+
[timestamps.append(timestamps[-1] + 1 / fps) for _ in range(pad_length)]
|
603 |
+
|
604 |
+
frames = [frame for frame in frames]
|
605 |
+
additional_frames = []
|
606 |
+
# print('frame_ids', frame_ids)
|
607 |
+
if frame_ids is not None:
|
608 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
609 |
+
additional_frames = vr.get_batch(frame_ids).asnumpy()
|
610 |
+
|
611 |
+
return frames, timestamps, additional_frames
|
612 |
+
|
613 |
+
|
614 |
+
def process_video(video_path, processor, s=None, e=None, aspect_ratio='pad', num_frames=None):
|
615 |
+
fps = 1 if num_frames is None else None
|
616 |
+
# FFmpeg
|
617 |
+
frames, timestamps, _ = load_video(video_path, s, e, fps=fps, max_frames=num_frames)
|
618 |
+
# Decord
|
619 |
+
# frames, timestamps = load_video_from_ids(video_path, s, e, fps=fps, max_frames=num_frames)
|
620 |
+
|
621 |
+
assert len(frames) == len(timestamps), "Number of frames and timestamps must match."
|
622 |
+
|
623 |
+
if aspect_ratio == 'pad':
|
624 |
+
frames = [expand2square(f, tuple(int(x*255) for x in processor.image_mean)) for f in frames]
|
625 |
+
|
626 |
+
if aspect_ratio == 'qwen2vl':
|
627 |
+
frames = [processor.preprocess(frame, return_tensors='pt', num_images=len(frames)) for frame in frames]
|
628 |
+
grid_frames = [frames]
|
629 |
+
else:
|
630 |
+
frames = processor.preprocess(frames, return_tensors='pt', num_images=len(frames))
|
631 |
+
grid_frames = [[frames]]
|
632 |
+
|
633 |
+
return grid_frames, timestamps
|
634 |
+
|
635 |
+
|
636 |
+
def tokenizer_multimodal_token(prompt, tokenizer, multimodal_token=DEFAULT_IMAGE_TOKEN, return_tensors=None):
|
637 |
+
"""Tokenize text and multimodal tag to input_ids.
|
638 |
+
|
639 |
+
Args:
|
640 |
+
prompt (str): Text prompt (w/ multimodal tag), e.g., '<video>\nDescribe the video.'
|
641 |
+
tokenizer (transformers.PreTrainedTokenizer): Tokenizer object.
|
642 |
+
multimodal_token (str): Multimodal tag (e.g., '<image>' or '<video>') whose occurrences are replaced by the corresponding token index.
|
643 |
+
"""
|
644 |
+
multimodal_token_index = MODAL_INDEX_MAP.get(multimodal_token, None)
|
645 |
+
if multimodal_token_index is None:
|
646 |
+
input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
|
647 |
+
else:
|
648 |
+
prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for idx, chunk in enumerate(prompt.split(multimodal_token))]
|
649 |
+
|
650 |
+
input_ids = []
|
651 |
+
for i in range(1, 2 * len(prompt_chunks)):
|
652 |
+
if i % 2 == 1:
|
653 |
+
input_ids.extend(prompt_chunks[i // 2])
|
654 |
+
else:
|
655 |
+
input_ids.append(multimodal_token_index)
|
656 |
+
|
657 |
+
if return_tensors is not None:
|
658 |
+
if return_tensors == 'pt':
|
659 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
660 |
+
raise ValueError(f'Unsupported tensor type: {return_tensors}')
|
661 |
+
return input_ids
|
662 |
+
|
663 |
+
|
664 |
+
def get_model_name_from_path(model_path):
|
665 |
+
model_path = model_path.strip("/")
|
666 |
+
model_paths = model_path.split("/")
|
667 |
+
if model_paths[-1].startswith('checkpoint-'):
|
668 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
669 |
+
else:
|
670 |
+
return model_paths[-1]
|
671 |
+
|
672 |
+
|
673 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
674 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
675 |
+
self.keywords = keywords
|
676 |
+
self.keyword_ids = []
|
677 |
+
self.max_keyword_len = 0
|
678 |
+
for keyword in keywords:
|
679 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
680 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
681 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
682 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
683 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
684 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
685 |
+
self.tokenizer = tokenizer
|
686 |
+
self.start_len = input_ids.shape[1]
|
687 |
+
|
688 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
689 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
690 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
691 |
+
for keyword_id in self.keyword_ids:
|
692 |
+
if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
|
693 |
+
return True
|
694 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
695 |
+
for keyword in self.keywords:
|
696 |
+
if keyword in outputs:
|
697 |
+
return True
|
698 |
+
return False
|
699 |
+
|
700 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
701 |
+
outputs = []
|
702 |
+
for i in range(output_ids.shape[0]):
|
703 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
704 |
+
return all(outputs)
|
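
Taken together, the helpers above form app.py's media preprocessing path: load_images normalizes every accepted input type into RGB arrays, process_images expands each image into a pad/dynamic/highres/anyres/adares grid, and load_video plus process_video handle frame sampling and timestamps. A minimal usage sketch from within app.py's namespace, assuming a `processor` (the vision encoder's image processor, exposing `image_mean`) has already been loaded; the media file names are placeholders:

# Sketch only: "example.jpg" and "example.mp4" are placeholder paths, and
# `processor` is assumed to be the image processor returned by load_pretrained_model.
image_grids = process_images("example.jpg", processor, aspect_ratio="pad")

grid_frames, timestamps = process_video("example.mp4", processor, num_frames=16)
print(len(timestamps))  # one timestamp per sampled frame
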
videollama3/model/__init__.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
2 |
+
# Copyright 2023 Haotian Liu
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import warnings
|
19 |
+
import shutil
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from transformers import PretrainedConfig, AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
|
23 |
+
|
24 |
+
from .projector import load_mm_projector
|
25 |
+
from .videollama3_qwen2 import Videollama3Qwen2ForCausalLM, Videollama3Qwen2Config
|
26 |
+
|
27 |
+
|
28 |
+
VLLMs = {
|
29 |
+
"videollama3_qwen2": Videollama3Qwen2ForCausalLM,
|
30 |
+
}
|
31 |
+
|
32 |
+
VLLMConfigs = {
|
33 |
+
"videollama3_qwen2": Videollama3Qwen2Config,
|
34 |
+
}
|
35 |
+
|
36 |
+
|
37 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", **kwargs):
|
38 |
+
if 'token' in kwargs:
|
39 |
+
token = kwargs['token']
|
40 |
+
else:
|
41 |
+
token = None
|
42 |
+
|
43 |
+
# NOTE: auto device_map by default
|
44 |
+
# if want to put model into a single device, you can set device_map={"": "cuda:0"}
|
45 |
+
kwargs = {"device_map": device_map, **kwargs}
|
46 |
+
|
47 |
+
config = AutoConfig.from_pretrained(model_path)
|
48 |
+
config._attn_implementation = kwargs.pop('attn_implementation', "flash_attention_2") # default to flash_attention_2
|
49 |
+
|
50 |
+
torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else kwargs.pop('torch_dtype', torch.float16)
|
51 |
+
|
52 |
+
if load_8bit:
|
53 |
+
kwargs['load_in_8bit'] = True
|
54 |
+
elif load_4bit:
|
55 |
+
# NOTE: High-version Transformers will report: """ValueError: You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time."""
|
56 |
+
# kwargs['load_in_4bit'] = True
|
57 |
+
kwargs['quantization_config'] = BitsAndBytesConfig(
|
58 |
+
load_in_4bit=True,
|
59 |
+
bnb_4bit_compute_dtype=torch_dtype,
|
60 |
+
bnb_4bit_use_double_quant=True,
|
61 |
+
bnb_4bit_quant_type='nf4'
|
62 |
+
)
|
63 |
+
else:
|
64 |
+
kwargs['torch_dtype'] = torch_dtype
|
65 |
+
|
66 |
+
# judge model type
|
67 |
+
model_type = config.model_type if hasattr(config, "model_type") else kwargs.pop('model_type', "videollama3_qwen2")
|
68 |
+
|
69 |
+
# judge pretrain/finetune
|
70 |
+
is_alignment = getattr(config, "tune_mm_mlp_adapter", False) or getattr(config, "is_alignment", False)
|
71 |
+
|
72 |
+
# NOTE: lora/qlora model loading
|
73 |
+
if 'lora' in model_name.lower() or 'qlora' in model_name.lower():
|
74 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
75 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
76 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
77 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
78 |
+
|
79 |
+
# NOTE: remove qlora training quantization config
|
80 |
+
if hasattr(cfg_pretrained, 'quantization_config'):
|
81 |
+
del cfg_pretrained.quantization_config
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
|
83 |
+
print('Loading VideoLLaMA from base model...')
|
84 |
+
|
85 |
+
if 'qwen2' in model_base.lower():
|
86 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
87 |
+
else:
|
88 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
89 |
+
|
90 |
+
token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
|
91 |
+
if model.lm_head.weight.shape[0] != token_num:
|
92 |
+
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
|
93 |
+
model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
|
94 |
+
|
95 |
+
print('Loading additional VideoLLaMA weights...')
|
96 |
+
if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
|
97 |
+
non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
|
98 |
+
else:
|
99 |
+
# this is probably from HF Hub
|
100 |
+
from huggingface_hub import hf_hub_download
|
101 |
+
def load_from_hf(repo_id, filename, subfolder=None):
|
102 |
+
cache_file = hf_hub_download(
|
103 |
+
repo_id=repo_id,
|
104 |
+
filename=filename,
|
105 |
+
subfolder=subfolder)
|
106 |
+
return torch.load(cache_file, map_location='cpu')
|
107 |
+
non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
|
108 |
+
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
|
109 |
+
if any(k.startswith('model.model.') for k in non_lora_trainables):
|
110 |
+
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
|
111 |
+
model.load_state_dict(non_lora_trainables, strict=False)
|
112 |
+
|
113 |
+
from peft import PeftModel
|
114 |
+
print('Loading LoRA weights...')
|
115 |
+
model = PeftModel.from_pretrained(model, model_path)
|
116 |
+
print('Merging LoRA weights...')
|
117 |
+
model = model.merge_and_unload()
|
118 |
+
print('Model is loaded...')
|
119 |
+
elif model_base is not None or '-base' in model_name.lower() or is_alignment:
|
120 |
+
# NOTE: Base/Pretrain model loading
|
121 |
+
print('Loading VideoLLaMA from base model...')
|
122 |
+
cfg_pretrained = PretrainedConfig.from_pretrained(model_path, token=token)
|
123 |
+
# NOTE: AutoConfig will modify `_name_or_path` property to `model_path` if `model_path` is not None.
|
124 |
+
# cfg_pretrained = AutoConfig.from_pretrained(model_path, token=token)
|
125 |
+
model_base = model_base if model_base is not None else cfg_pretrained._name_or_path
|
126 |
+
|
127 |
+
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, token=token)
|
128 |
+
|
129 |
+
if model_type in ['videollama3', 'videollama3_qwen2']:
|
130 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
131 |
+
else:
|
132 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=config, **kwargs)
|
133 |
+
|
134 |
+
# NOTE: loading vision-language projector
|
135 |
+
# * old codes for loading local mm_projector.bin
|
136 |
+
# mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
|
137 |
+
# mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
138 |
+
# model.load_state_dict(mm_projector_weights, strict=False)
|
139 |
+
# * new codes which supports loading mm_projector.bin both offline and online
|
140 |
+
mm_projector_weights = load_mm_projector(model_path, token=token)
|
141 |
+
model.load_state_dict(mm_projector_weights, strict=False)
|
142 |
+
elif 'videollama' in model_type:
|
143 |
+
# NOTE: SFT model loading
|
144 |
+
print(model_path)
|
145 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, token=token)
|
146 |
+
|
147 |
+
if model_type in ['videollama3_qwen2']:
|
148 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
149 |
+
else:
|
150 |
+
model = Videollama3Qwen2ForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=config, **kwargs)
|
151 |
+
else:
|
152 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, token=token)
|
153 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, config=config, **kwargs)
|
154 |
+
|
155 |
+
processor = None
|
156 |
+
|
157 |
+
if "videollama" in model_type:
|
158 |
+
vision_encoder = model.get_vision_encoder()
|
159 |
+
processor = vision_encoder.image_processor
|
160 |
+
|
161 |
+
if hasattr(model.config, "max_sequence_length"):
|
162 |
+
context_len = model.config.max_sequence_length
|
163 |
+
else:
|
164 |
+
context_len = 2048
|
165 |
+
|
166 |
+
return tokenizer, model, processor, context_len
|
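
load_pretrained_model dispatches on the checkpoint layout: LoRA/QLoRA checkpoints are merged back into their base model, alignment/pretrain checkpoints load only the multimodal projector on top of model_base, full SFT checkpoints of a 'videollama' model type load directly, and anything else falls back to a plain AutoModelForCausalLM. A minimal call for an SFT-style checkpoint could look like the sketch below; the checkpoint path is a placeholder, not a published model id:

# Sketch, assuming a checkpoint whose config declares model_type == "videollama3_qwen2".
from videollama3.model import load_pretrained_model

tokenizer, model, processor, context_len = load_pretrained_model(
    model_path="path/to/videollama3_qwen2_checkpoint",  # placeholder
    model_base=None,
    model_name="videollama3_qwen2",
)
print(context_len)  # falls back to 2048 when the config has no max_sequence_length
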
videollama3/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (4.04 kB)
videollama3/model/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (10.5 kB)
videollama3/model/__pycache__/processor.cpython-310.pyc
ADDED
Binary file (12.2 kB)
videollama3/model/__pycache__/projector.cpython-310.pyc
ADDED
Binary file (5.11 kB)
videollama3/model/__pycache__/region_encoder.cpython-310.pyc
ADDED
Binary file (3.43 kB)
videollama3/model/__pycache__/videollama3_arch.cpython-310.pyc
ADDED
Binary file (9.74 kB)
videollama3/model/__pycache__/videollama3_qwen2.cpython-310.pyc
ADDED
Binary file (4.2 kB)
videollama3/model/damovl_encoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
1 |
+
from .configuration_damovl_encoder import DAMOVLVisionConfig
|
2 |
+
from .image_processing import DAMOVLImageProcessor
|
3 |
+
from .modeling_damovl_encoder import DAMOVLVisionModel
|
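
The package __init__ above simply re-exports the three public classes, so downstream code can import the DAMOVL config, image processor, and vision model from one place. A small sketch (processor defaults taken from the definitions that follow):

# Sketch: pull in the exported classes with a single import.
from videollama3.model.damovl_encoder import (
    DAMOVLVisionConfig,
    DAMOVLImageProcessor,
    DAMOVLVisionModel,
)

image_processor = DAMOVLImageProcessor()  # OpenAI CLIP mean/std, dynamic-resolution resizing
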
videollama3/model/damovl_encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (407 Bytes)
videollama3/model/damovl_encoder/__pycache__/configuration_damovl_encoder.cpython-310.pyc
ADDED
Binary file (1.96 kB)
videollama3/model/damovl_encoder/__pycache__/image_processing.cpython-310.pyc
ADDED
Binary file (16.7 kB)
videollama3/model/damovl_encoder/__pycache__/modeling_damovl_encoder.cpython-310.pyc
ADDED
Binary file (16.9 kB)
videollama3/model/damovl_encoder/configuration_damovl_encoder.py
ADDED
@@ -0,0 +1,71 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Qwen2VL model configuration"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import Union
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class DAMOVLVisionConfig(PretrainedConfig):
|
28 |
+
model_type = "damovl"
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
hidden_size=768,
|
33 |
+
intermediate_size=3072,
|
34 |
+
num_hidden_layers=12,
|
35 |
+
num_attention_heads=12,
|
36 |
+
num_channels=3,
|
37 |
+
patch_size=16,
|
38 |
+
hidden_act="gelu_pytorch_tanh",
|
39 |
+
layer_norm_eps=1e-6,
|
40 |
+
attention_dropout=0.0,
|
41 |
+
spatial_merge_size=1,
|
42 |
+
**kwargs,
|
43 |
+
):
|
44 |
+
super().__init__(**kwargs)
|
45 |
+
|
46 |
+
self.hidden_size = hidden_size
|
47 |
+
self.intermediate_size = intermediate_size
|
48 |
+
self.num_hidden_layers = num_hidden_layers
|
49 |
+
self.num_attention_heads = num_attention_heads
|
50 |
+
self.num_channels = num_channels
|
51 |
+
self.patch_size = patch_size
|
52 |
+
self.attention_dropout = attention_dropout
|
53 |
+
self.layer_norm_eps = layer_norm_eps
|
54 |
+
self.hidden_act = hidden_act
|
55 |
+
self.spatial_merge_size = spatial_merge_size
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
59 |
+
cls._set_token_in_kwargs(kwargs)
|
60 |
+
|
61 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
62 |
+
|
63 |
+
# config_dict = config_dict["vision_config"]
|
64 |
+
|
65 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
66 |
+
logger.warning(
|
67 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
68 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
69 |
+
)
|
70 |
+
|
71 |
+
return cls.from_dict(config_dict, **kwargs)
|
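
DAMOVLVisionConfig is a thin PretrainedConfig with ViT-style hyperparameters; its from_pretrained only adds a warning when the stored model_type differs from "damovl". Constructing it explicitly or from a saved checkpoint directory might look like the following sketch; the path is a placeholder:

# Sketch; "path/to/checkpoint" is a placeholder directory containing a config.json.
vision_cfg = DAMOVLVisionConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    patch_size=16,
    spatial_merge_size=1,
)
# vision_cfg = DAMOVLVisionConfig.from_pretrained("path/to/checkpoint")
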
videollama3/model/damovl_encoder/image_processing.py
ADDED
@@ -0,0 +1,472 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""Image processor class for Qwen2-VL."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from typing import Dict, List, Optional, Union
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
28 |
+
from transformers.image_transforms import (
|
29 |
+
convert_to_rgb,
|
30 |
+
resize,
|
31 |
+
to_channel_dimension_format,
|
32 |
+
)
|
33 |
+
from transformers.image_utils import (
|
34 |
+
OPENAI_CLIP_MEAN,
|
35 |
+
OPENAI_CLIP_STD,
|
36 |
+
ChannelDimension,
|
37 |
+
ImageInput,
|
38 |
+
PILImageResampling,
|
39 |
+
VideoInput,
|
40 |
+
get_image_size,
|
41 |
+
infer_channel_dimension_format,
|
42 |
+
is_scaled_image,
|
43 |
+
is_valid_image,
|
44 |
+
make_list_of_images,
|
45 |
+
to_numpy_array,
|
46 |
+
valid_images,
|
47 |
+
validate_preprocess_arguments,
|
48 |
+
)
|
49 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
50 |
+
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
if is_vision_available():
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
|
59 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
60 |
+
"""
|
61 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
65 |
+
The input image.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
list: A list of images.
|
69 |
+
"""
|
70 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
71 |
+
return [img for img_list in images for img in img_list]
|
72 |
+
|
73 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
74 |
+
return images
|
75 |
+
|
76 |
+
elif is_valid_image(images):
|
77 |
+
return [images]
|
78 |
+
|
79 |
+
raise ValueError(f"Could not make batched images from {images}")
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
83 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
84 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
85 |
+
return videos
|
86 |
+
|
87 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
88 |
+
if isinstance(videos[0], Image.Image):
|
89 |
+
return [videos]
|
90 |
+
elif len(videos[0].shape) == 4:
|
91 |
+
return [list(video) for video in videos]
|
92 |
+
|
93 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
94 |
+
return [list(videos)]
|
95 |
+
|
96 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
97 |
+
|
98 |
+
|
99 |
+
def smart_resize(
|
100 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
101 |
+
):
|
102 |
+
"""Rescales the image so that the following conditions are met:
|
103 |
+
|
104 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
105 |
+
|
106 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
107 |
+
|
108 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
109 |
+
|
110 |
+
"""
|
111 |
+
if height < factor or width < factor:
|
112 |
+
scale = factor / min(height, width)
|
113 |
+
width = round(scale * width)
|
114 |
+
height = round(scale * height)
|
115 |
+
elif max(height, width) / min(height, width) > 200:
|
116 |
+
raise ValueError(
|
117 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
118 |
+
)
|
119 |
+
h_bar = round(height / factor) * factor
|
120 |
+
w_bar = round(width / factor) * factor
|
121 |
+
if h_bar * w_bar > max_pixels:
|
122 |
+
beta = math.sqrt((height * width) / max_pixels)
|
123 |
+
h_bar = math.floor(height / beta / factor) * factor
|
124 |
+
w_bar = math.floor(width / beta / factor) * factor
|
125 |
+
elif h_bar * w_bar < min_pixels:
|
126 |
+
beta = math.sqrt(min_pixels / (height * width))
|
127 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
128 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
129 |
+
return h_bar, w_bar
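# Illustrative example (not part of the original file): with the default
# factor=28, smart_resize first snaps each side to the nearest multiple of the
# factor and only rescales further when the pixel count leaves
# [min_pixels, max_pixels]:
#   smart_resize(480, 640)             # -> (476, 644); 476 * 644 = 306_544 stays within bounds
#   smart_resize(100, 100, factor=14)  # -> (98, 98); nearest multiples of 14
# Oversized inputs are shrunk by roughly sqrt(H * W / max_pixels) before snapping.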
|
130 |
+
|
131 |
+
|
132 |
+
class DAMOVLImageProcessor(BaseImageProcessor):
|
133 |
+
r"""
|
134 |
+
Constructs a DAMOVL image processor that dynamically resizes images based on the original images.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
138 |
+
Whether to resize the image's (height, width) dimensions.
|
139 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
140 |
+
Resampling filter to use when resizing the image.
|
141 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
143 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
144 |
+
Scale factor to use if rescaling the image.
|
145 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
146 |
+
Whether to normalize the image.
|
147 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
148 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
149 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
150 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
151 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
152 |
+
Whether to convert the image to RGB.
|
153 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
154 |
+
The min pixels of the image to resize the image.
|
155 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
156 |
+
The max pixels of the image to resize the image.
|
157 |
+
patch_size (`int`, *optional*, defaults to 14):
|
158 |
+
The spacial patch size of the vision encoder.
|
159 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
160 |
+
The temporal patch size of the vision encoder.
|
161 |
+
merge_size (`int`, *optional*, defaults to 2):
|
162 |
+
The merge size of the vision encoder to llm encoder.
|
163 |
+
"""
|
164 |
+
|
165 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
do_resize: bool = True,
|
170 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
171 |
+
do_rescale: bool = True,
|
172 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
173 |
+
do_normalize: bool = True,
|
174 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
175 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
176 |
+
do_convert_rgb: bool = True,
|
177 |
+
min_pixels: int = 56 * 56,
|
178 |
+
max_pixels: int = 14 * 14 * 9477,
|
179 |
+
patch_size: int = 14,
|
180 |
+
merge_size: int = 1,
|
181 |
+
**kwargs,
|
182 |
+
) -> None:
|
183 |
+
super().__init__(**kwargs)
|
184 |
+
self.do_resize = do_resize
|
185 |
+
self.resample = resample
|
186 |
+
self.do_rescale = do_rescale
|
187 |
+
self.rescale_factor = rescale_factor
|
188 |
+
self.do_normalize = do_normalize
|
189 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
190 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
191 |
+
self.min_pixels = min_pixels
|
192 |
+
self.max_pixels = max_pixels
|
193 |
+
self.patch_size = patch_size
|
194 |
+
self.merge_size = merge_size
|
195 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
196 |
+
self.do_convert_rgb = do_convert_rgb
|
197 |
+
|
198 |
+
self.temporal_patch_size = 1
|
199 |
+
|
200 |
+
def _preprocess(
|
201 |
+
self,
|
202 |
+
images: Union[ImageInput, VideoInput],
|
203 |
+
do_resize: bool = None,
|
204 |
+
resample: PILImageResampling = None,
|
205 |
+
do_rescale: bool = None,
|
206 |
+
rescale_factor: float = None,
|
207 |
+
do_normalize: bool = None,
|
208 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
209 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
210 |
+
do_convert_rgb: bool = None,
|
211 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
212 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
213 |
+
num_images: Optional[int] = 1,
|
214 |
+
image_downsampling: Optional[int] = None,
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
images (`ImageInput`):
|
221 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
222 |
+
vision_info (`List[Dict]`, *optional*):
|
223 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
224 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
225 |
+
Whether to resize the image.
|
226 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
227 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
228 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
229 |
+
Whether to rescale the image.
|
230 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
231 |
+
Scale factor to use if rescaling the image.
|
232 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
233 |
+
Whether to normalize the image.
|
234 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
235 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
236 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
237 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
238 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
239 |
+
Whether to convert the image to RGB.
|
240 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
241 |
+
The channel dimension format for the output image. Can be one of:
|
242 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
243 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
244 |
+
- Unset: Use the channel dimension format of the input image.
|
245 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
246 |
+
The channel dimension format for the input image. Can be one of:
|
247 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
248 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
249 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
250 |
+
"""
|
251 |
+
images = make_list_of_images(images)
|
252 |
+
|
253 |
+
if do_convert_rgb:
|
254 |
+
images = [convert_to_rgb(image) for image in images]
|
255 |
+
|
256 |
+
# All transformations expect numpy arrays.
|
257 |
+
images = [to_numpy_array(image) for image in images]
|
258 |
+
|
259 |
+
if is_scaled_image(images[0]) and do_rescale:
|
260 |
+
logger.warning_once(
|
261 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
262 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
263 |
+
)
|
264 |
+
if input_data_format is None:
|
265 |
+
# We assume that all images have the same channel dimension format.
|
266 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
267 |
+
|
268 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
269 |
+
resized_height, resized_width = height, width
|
270 |
+
processed_images = []
|
271 |
+
for image in images:
|
272 |
+
if do_resize:
|
273 |
+
max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
|
274 |
+
resized_height, resized_width = smart_resize(
|
275 |
+
height,
|
276 |
+
width,
|
277 |
+
factor=self.patch_size * image_downsampling,
|
278 |
+
min_pixels=self.min_pixels,
|
279 |
+
max_pixels=int(max_pixels // num_images),
|
280 |
+
)
|
281 |
+
image = resize(
|
282 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
283 |
+
)
|
284 |
+
|
285 |
+
if do_rescale:
|
286 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
287 |
+
|
288 |
+
if do_normalize:
|
289 |
+
image = self.normalize(
|
290 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
291 |
+
)
|
292 |
+
|
293 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
294 |
+
processed_images.append(image)
|
295 |
+
|
296 |
+
patches = np.array(processed_images)
|
297 |
+
if data_format == ChannelDimension.LAST:
|
298 |
+
patches = patches.transpose(0, 3, 1, 2)
        channel = patches.shape[1]
        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
        patches = patches.reshape(
            channel,
            grid_h // image_downsampling,
            image_downsampling,
            self.patch_size,
            grid_w // image_downsampling,
            image_downsampling,
            self.patch_size,
        )
        patches = patches.transpose(1, 4, 2, 5, 0, 3, 6)
        flatten_patches = patches.reshape(
            grid_h * grid_w, channel * self.patch_size * self.patch_size
        )
        # print('image_downsampling', image_downsampling)
        # flatten_patches1 = flatten_patches.reshape(grid_h, grid_w, channel, -1)
        # from matplotlib import pyplot as plt
        # plt.imshow(flatten_patches1[:,:,:,0])
        # plt.savefig('8.png')

        return flatten_patches, (1, grid_h, grid_w)

    def preprocess(
        self,
        images: ImageInput,
        videos: VideoInput = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        num_images: Optional[int] = 1,
        image_downsampling: Optional[int] = None,
    ):
        """
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            videos (`VideoInput`):
                Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
                passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
                the longest edge resized to keep the input aspect ratio.
            resample (`int`, *optional*, defaults to `self.resample`):
                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
                has an effect if `do_resize` is set to `True`.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image.
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
                `True`.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - Unset: Use the channel dimension format of the input image.
            input_data_format (`ChannelDimension` or `str`, *optional*):
                The channel dimension format for the input image. If unset, the channel dimension format is inferred
                from the input image. Can be one of:
                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
                - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
        image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size

        if images is not None:
            images = make_batched_images(images)
        if videos is not None:
            videos = make_batched_videos(videos)

        if images is not None and not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        validate_preprocess_arguments(
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_resize=do_resize,
            size=size,
            resample=resample,
        )

        if images is not None:
            pixel_values, vision_grid_thws = [], []
            for image in images:
                patches, image_grid_thw = self._preprocess(
                    image,
                    do_resize=do_resize,
                    resample=resample,
                    do_rescale=do_rescale,
                    rescale_factor=rescale_factor,
                    do_normalize=do_normalize,
                    image_mean=image_mean,
                    image_std=image_std,
                    data_format=data_format,
                    do_convert_rgb=do_convert_rgb,
                    input_data_format=input_data_format,
                    num_images=num_images,
                    image_downsampling=image_downsampling,
                )
                pixel_values.extend(patches)
                vision_grid_thws.append(image_grid_thw)
            pixel_values = np.array(pixel_values)
            vision_grid_thws = np.array(vision_grid_thws)
            data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}

        assert videos is None, "Not support video for now."
        # NOTE: not support video for now
        # if videos is not None:
        #     pixel_values, vision_grid_thws = [], []
        #     for images in videos:
        #         patches, video_grid_thw = self._preprocess(
        #             images,
        #             do_resize=do_resize,
        #             resample=resample,
        #             do_rescale=do_rescale,
        #             rescale_factor=rescale_factor,
        #             do_normalize=do_normalize,
        #             image_mean=image_mean,
        #             image_std=image_std,
        #             data_format=data_format,
        #             do_convert_rgb=do_convert_rgb,
        #             input_data_format=input_data_format,
        #             image_num=image_num,
        #         )
        #         pixel_values.extend(patches)
        #         vision_grid_thws.append(video_grid_thw)
        #     pixel_values = np.array(pixel_values)
        #     vision_grid_thws = np.array(vision_grid_thws)
        #     data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}

        return BatchFeature(data=data, tensor_type=return_tensors)
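Note: the following is a minimal usage sketch of the `preprocess` entry point above, added for orientation only; it is not part of the committed file. The checkpoint path is a placeholder and the input image is synthetic; argument names follow the signature shown above, and `image_downsampling` is left at its default (the processor's `merge_size`).

# Usage sketch (illustrative, not part of the repository).
import numpy as np
from PIL import Image

from videollama3.model.damovl_encoder import DAMOVLImageProcessor

# Hypothetical local path: any directory with a compatible preprocessor_config.json.
processor = DAMOVLImageProcessor.from_pretrained("path/to/vision_encoder")

# A dummy 336x448 RGB image stands in for real input.
image = Image.fromarray(np.zeros((336, 448, 3), dtype=np.uint8))

batch = processor.preprocess(images=image, return_tensors="np")
# pixel_values: flattened patches of shape (grid_h * grid_w, channel * patch_size**2)
# image_grid_thw: one (t, h, w) grid triple per image, with t == 1 for still images
print(batch["pixel_values"].shape, batch["image_grid_thw"])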
videollama3/model/damovl_encoder/modeling_damovl_encoder.py
ADDED
@@ -0,0 +1,542 @@
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Siglip model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_start_docstrings,
                                add_start_docstrings_to_model_forward,
                                is_flash_attn_2_available,
                                is_flash_attn_greater_or_equal_2_10, logging,
                                replace_return_docstrings)
from .configuration_damovl_encoder import DAMOVLVisionConfig


if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import \
        _flash_attention_forward
else:
    flash_attn_varlen_func = None


logger = logging.get_logger(__name__)


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)


def trunc_normal_tf_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
    """Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.

    NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)


def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")


def default_flax_embed_init(tensor):
    variance_scaling_(tensor, mode="fan_in", distribution="normal")


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output


class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs


class DAMOVLVisionEmbeddings(nn.Module):
    def __init__(self, config: DAMOVLVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states.view(
            -1, self.config.num_channels, self.patch_size, self.patch_size
        )
        patch_embeds = self.patch_embedding(hidden_states)  # shape = [*, width, grid, grid]
        # embeddings = patch_embeds.flatten(2).transpose(1, 2)
        embeddings = patch_embeds.view(-1, self.embed_dim)

        return embeddings


class VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Input shape: Time x Channel"""

        q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_heads, self.head_dim)

        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # (fixed) `q` was undefined here; build an additive float mask so that
        # patches only attend to patches belonging to the same image segment.
        attention_mask = torch.full(
            [1, q_len, q_len], torch.finfo(query_states.dtype).min,
            device=query_states.device, dtype=query_states.dtype,
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        query_states = query_states.transpose(0, 1)
        key_states = key_states.transpose(0, 1)
        value_states = value_states.transpose(0, 1)

        attn_weights = torch.matmul(query_states, key_states.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(q_len, -1)
        attn_output = self.out_proj(attn_output)

        return attn_output


class VisionFlashAttention2(VisionAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(q_len, self.num_heads, self.head_dim)
        value_states = value_states.view(q_len, self.num_heads, self.head_dim)
        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            q_len, -1
        )
        attn_output = self.out_proj(attn_output)

        return attn_output


class VisionSdpaAttention(VisionAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
        output_attentions: bool = False,  # (fixed) parameter was missing, so the guard below raised a NameError
    ) -> torch.Tensor:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DAMOVLVisionModel is using VisionSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
            )

        seq_length = hidden_states.shape[0]
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (fixed) these views previously used the undefined name `q_len`
        query_states = query_states.view(seq_length, self.num_heads, self.head_dim)
        key_states = key_states.view(seq_length, self.num_heads, self.head_dim)
        value_states = value_states.view(seq_length, self.num_heads, self.head_dim)

        query_states = apply_rotary_pos_emb_vision(query_states.unsqueeze(0), rotary_pos_emb).squeeze(0)
        key_states = apply_rotary_pos_emb_vision(key_states.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # (fixed) `q` was undefined; a boolean mask (True = attend) is what SDPA expects
        attention_mask = torch.zeros([1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True

        query_states = query_states.transpose(0, 1)
        key_states = key_states.transpose(0, 1)
        value_states = value_states.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.out_proj(attn_output)  # (fixed) was `self.proj`, which is not defined on VisionAttention
        return attn_output


DAMOVL_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->DAMOVL
class DAMOVLVisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class DAMOVLVisionEncoderLayer(nn.Module):
    def __init__(self, config: DAMOVLVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = DAMOVL_VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = DAMOVLVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    # Ignore copy
    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        hidden_states = hidden_states + self.self_attn(
            self.layer_norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
        return hidden_states


class DAMOVLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DAMOVLVisionConfig
    base_model_prefix = "damovl"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "DAMOVLVisionEncoderLayer",
        "DAMOVLVisionEmbeddings",
    ]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, VisionAttention):
            nn.init.xavier_uniform_(module.q_proj.weight)
            nn.init.xavier_uniform_(module.k_proj.weight)
            nn.init.xavier_uniform_(module.v_proj.weight)
            nn.init.xavier_uniform_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, DAMOVLVisionMLP):
            nn.init.xavier_uniform_(module.fc1.weight)
            nn.init.xavier_uniform_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class DAMOVLVisionEncoder(nn.Module):
    def __init__(self, config: DAMOVLVisionConfig):
        super().__init__()
        self.config = config
        head_dim = config.hidden_size // config.num_attention_heads
        self.spatial_merge_size = config.spatial_merge_size
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
        self.layers = nn.ModuleList([DAMOVLVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw, strides):
        pos_ids = []
        for (t, h, w), stride in zip(grid_thw, strides):
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
        # BUG: These codes will cause deepspeed issue: `RuntimeError: disagreement between rank0 and rankx`
        # rotary_pos_emb = []
        # for thw in grid_thws:
        #     rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
        # rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
        # grid_thws = torch.cat(grid_thws, dim = 0)

        # new version of creating rotary position embedding
        # grid_thws shapes like [batch_flatten_image_num, 3]
        # grid_thws = torch.cat(grid_thws, dim = 0) # is conducted in the `encoder.py`
        rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)

        cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        return hidden_states


class DAMOVLVisionTransformer(nn.Module):
    def __init__(self, config: DAMOVLVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = DAMOVLVisionEmbeddings(config)
        self.encoder = DAMOVLVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:

        # print(hidden_states)

        # hidden_states = torch.cat(hidden_states, dim = 1)

        hidden_states = self.embeddings(hidden_states)
        hidden_states = self.encoder(hidden_states, grid_thws, strides)
        hidden_states = self.post_layernorm(hidden_states)

        return hidden_states


class DAMOVLVisionModel(DAMOVLPreTrainedModel):
    config_class = DAMOVLVisionConfig
    main_input_name = "hidden_states"

    def __init__(self, config: DAMOVLVisionConfig):
        super().__init__(config)

        self.vision_model = DAMOVLVisionTransformer(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
        return self.vision_model(hidden_states=hidden_states, grid_thws=grid_thws, strides=strides)
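Note: the following is a small, self-contained illustration of the `cu_seqlens` bookkeeping used by the encoder above, with made-up grid sizes. It mirrors the construction in `DAMOVLVisionEncoder.forward` and the per-image block masking in the attention classes; it is not part of the committed file.

import torch
import torch.nn.functional as F

# Two images flattened into one token sequence: (t, h, w) grids of (1, 4, 6) and (1, 2, 2).
grid_thws = torch.tensor([[1, 4, 6], [1, 2, 2]])

# Same construction as in DAMOVLVisionEncoder.forward: one boundary per frame.
cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 28]) -> tokens 0..23 belong to image 0, tokens 24..27 to image 1

# The eager/SDPA paths turn these boundaries into a block-diagonal mask so that
# patches of one image never attend to patches of another.
seq_len = int(cu_seqlens[-1])
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    mask[cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True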
videollama3/model/encoder.py
ADDED
@@ -0,0 +1,385 @@
import os

import torch
import torch.nn as nn
from transformers import (CLIPImageProcessor, CLIPVisionConfig,
                          CLIPVisionModel, SiglipImageProcessor,
                          SiglipVisionConfig, SiglipVisionModel)

from .qwen2vl_encoder import (Qwen2VisionTransformerPretrainedModel,
                              Qwen2VLImageProcessor, Qwen2VLVisionConfig)

# NOTE: DAMOVLVisionConfig added to this import; it is referenced by the
# delay_load branch of DAMOVLVisionEncoder below.
from .damovl_encoder import (DAMOVLImageProcessor, DAMOVLVisionConfig, DAMOVLVisionModel)


class CLIPVisionEncoder(nn.Module):

    def __init__(self, vision_encoder, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = vision_encoder
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
            self.load_model()
        else:
            # uncertain whether flash-attention-2 is supported during inference phase.
            self.attn_implementation = 'sdpa'  # 'eager'
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self):
        if self.is_loaded:
            print('Vision tower is already loaded, `load model` call again, skipping.')
            return

        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_encoder_name)

        self.vision_encoder = CLIPVisionModel.from_pretrained(self.vision_encoder_name,
                                                              attn_implementation=self.attn_implementation)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    def forward(self, images, **kwargs):
        images = torch.cat(images)
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_encoder.dtype

    @property
    def device(self):
        return self.vision_encoder.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_encoder.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size


class SiglipVisionEncoder(nn.Module):

    def __init__(self, vision_encoder, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = vision_encoder
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
            self.load_model()
        else:
            # uncertain whether flash-attention-2 is supported during inference phase.
            self.attn_implementation = 'sdpa'  # 'eager'
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self):
        if self.is_loaded:
            print('Vision tower is already loaded, `load model` call again, skipping.')
            return

        self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_encoder_name)

        self.vision_encoder = SiglipVisionModel.from_pretrained(self.vision_encoder_name,
                                                                attn_implementation=self.attn_implementation)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    def forward(self, images, **kwargs):
        images = torch.cat(images)
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_encoder(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_encoder(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_encoder.dtype

    @property
    def device(self):
        return self.vision_encoder.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_encoder.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2

    @property
    def num_patches_per_side(self):
        return self.config.image_size // self.config.patch_size

    @property
    def image_size(self):
        return self.config.image_size


class Qwen2VLVisionEncoder(nn.Module):

    def __init__(self, vision_encoder, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = vision_encoder
        self.select_layer = args.mm_vision_select_layer

        if not delay_load:
            self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
            self.load_model(args)
        else:
            # uncertain whether flash-attention-2 is supported during inference phase.
            self.attn_implementation = 'sdpa'  # 'eager'
            self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self, args):
        if self.is_loaded:
            print('Vision tower is already loaded, `load model` call again, skipping.')
            return

        # merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
        # for stage 3, the merge_size is set to 2 by arguments.
        self.image_processor = Qwen2VLImageProcessor.from_pretrained(self.vision_encoder_name)
        self.image_processor.merge_size = args.spatial_merge_size
        # NOTE: The maximum number of vision tokens is 8192 by default.
        mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
        self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
        self.image_processor.size["max_pixels"] = self.image_processor.max_pixels

        # merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
        self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
        self.cfg_only.spatial_merge_size = args.spatial_merge_size

        self.vision_encoder = Qwen2VisionTransformerPretrainedModel.from_pretrained(
            self.vision_encoder_name,
            config=self.cfg_only,
            torch_dtype=args.torch_dtype,
            attn_implementation=self.attn_implementation)

        self.is_loaded = True

    def forward(self, images, grid_thws, strides, **kwargs):
        images = [image for sub_images in images for image in sub_images]
        grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
        strides = [stride for sub_strides in strides for stride in sub_strides]

        images = torch.cat(images, dim=0)
        grid_thws = torch.cat(grid_thws, dim=0)

        image_features = self.vision_encoder(images, grid_thws, strides=strides)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_encoder.dtype

    @property
    def device(self):
        return self.vision_encoder.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_encoder.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return -1

    @property
    def num_patches_per_side(self):
        return -1

    @property
    def image_size(self):
        return 14 * self.vision_encoder.config.spatial_merge_size


class DAMOVLVisionEncoder(nn.Module):

    def __init__(self, vision_encoder, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_encoder_name = vision_encoder
        self.args = args

        if not delay_load:
            self.attn_implementation = getattr(args, 'mm_attn_implementation', 'flash_attention_2')
            self.load_model(self.args)
        else:
            # uncertain whether flash-attention-2 is supported during inference phase.
            self.attn_implementation = 'sdpa'  # 'eager'
            self.cfg_only = DAMOVLVisionConfig.from_pretrained(self.vision_encoder_name)

    def load_model(self, args):
        if self.is_loaded:
            print('Vision tower is already loaded, `load model` call again, skipping.')
            return

        # merge_size is set to 1 by default, because STAGE1, STAGE1.5, STAGE2 are trained with merge_size=1
        # for stage 3, the merge_size is set to 2 by arguments.
        self.image_processor = DAMOVLImageProcessor.from_pretrained(self.vision_encoder_name)
        self.image_processor.merge_size = args.spatial_merge_size
        # NOTE: The maximum number of vision tokens is 8192 by default.
        mm_max_length = args.mm_max_length if hasattr(args, 'mm_max_length') else 9477 // (args.spatial_merge_size**2)
        self.image_processor.max_pixels = mm_max_length * (args.spatial_merge_size**2 * self.image_processor.patch_size**2)
        self.image_processor.size["max_pixels"] = self.image_processor.max_pixels

        # merge_size is fixed to 1 for STAGE1, STAGE1.5, STAGE2, STAGE3 in encoder and can be modified in connector.
        self.cfg_only = Qwen2VLVisionConfig.from_pretrained(self.vision_encoder_name)
        self.cfg_only.spatial_merge_size = args.spatial_merge_size

        self.vision_encoder = DAMOVLVisionModel.from_pretrained(
            self.vision_encoder_name,
            spatial_merge_size=args.spatial_merge_size,
            torch_dtype=args.torch_dtype,
            attn_implementation=self.attn_implementation)

        self.is_loaded = True

    def forward(self, images, grid_thws, strides, **kwargs):
        images = [image for sub_images in images for image in sub_images]
        grid_thws = [grid_thw for sub_grid_thws in grid_thws for grid_thw in sub_grid_thws]
        strides = [stride for sub_strides in strides for stride in sub_strides]

        images = torch.cat(images, dim=0)
        grid_thws = torch.cat(grid_thws, dim=0)

        image_features = self.vision_encoder(images, grid_thws, strides)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_encoder.dtype

    @property
    def device(self):
        return self.vision_encoder.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_encoder.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return -1

    @property
    def num_patches_per_side(self):
        return -1

    @property
    def image_size(self):
        return 14 * self.vision_encoder.config.spatial_merge_size


def build_vision_encoder(vision_encoder_cfg, **kwargs):

    vision_encoder = getattr(vision_encoder_cfg, 'mm_vision_encoder', getattr(vision_encoder_cfg, 'vision_encoder', None))

    vision_encoder = DAMOVLVisionEncoder(vision_encoder, args=vision_encoder_cfg, **kwargs)

    return vision_encoder
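Note: a rough sketch of how `build_vision_encoder` might be driven from a config object, for orientation only. The attribute names follow the fields read in `encoder.py`; the checkpoint path is a placeholder, and even with `delay_load=True` a real local config is needed before this would actually run.

from types import SimpleNamespace

import torch

from videollama3.model.encoder import build_vision_encoder

# Illustrative config object; only the fields read by encoder.py are set.
cfg = SimpleNamespace(
    mm_vision_encoder="path/to/damovl_vision_encoder",  # placeholder path
    mm_vision_select_layer=-1,
    mm_attn_implementation="sdpa",
    spatial_merge_size=2,
    torch_dtype=torch.float32,  # used only when the weights are actually loaded
)

# delay_load=True skips loading the weights; only the vision config is read.
vision_encoder = build_vision_encoder(cfg, delay_load=True)
print(vision_encoder.config)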
videollama3/model/processor.py
ADDED
@@ -0,0 +1,366 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for VideoLLaMA3.
"""
import copy
import math
import warnings
from typing import List, Union, Dict, Optional

import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

import sys
sys.path.append(".")
from videollama3.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX


DEFAULT_CHAT_TEMPLATE = """
{%- set identifier = 'im' %}
{% for message in messages %}
    {% if message['role'] == 'stream' %}
        {% set identifier = 'stream' %}
    {% else %}
        {% set identifier = 'im' %}
    {% endif %}
    {{- '<|' + identifier + '_start|>' + message['role'] + '\n' -}}
    {% if message['content'] is string %}
        {{- message['content'] + '<|' + identifier + '_end|>\n' -}}
    {% else %}
        {% for content in message['content'] %}
            {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}
                {% if 'time' in content %}
                    {{- 'Time ' + content['time'] | round(1) | string + 's: ' -}}
                {% endif %}
"""
DEFAULT_CHAT_TEMPLATE += """
                {{- '%s\n' -}}
""" % DEFAULT_IMAGE_TOKEN
DEFAULT_CHAT_TEMPLATE += """
            {% elif content['type'] == 'video' or 'video' in content or 'video_url' in content %}
                {% for i in range(content['num_frames']) %}
                    {% if 'time' in content %}
                        {{- 'Time ' + content['time'][i] | round(1) | string + 's:' -}}
                    {% endif %}
                    {% if i < content['num_frames'] - 1 %}
"""
DEFAULT_CHAT_TEMPLATE += """
                        {{- '%s,' -}}
""" % DEFAULT_IMAGE_TOKEN
DEFAULT_CHAT_TEMPLATE += """
                    {% else %}
"""
DEFAULT_CHAT_TEMPLATE += """
                        {{- '%s\n' -}}
""" % DEFAULT_IMAGE_TOKEN
DEFAULT_CHAT_TEMPLATE += """
                    {% endif %}
                {% endfor %}
            {% elif 'text' in content %}
                {{- content['text'] -}}
            {% endif %}
        {% endfor %}
        {{- '<|' + identifier + '_end|>\n' -}}
    {% endif %}
{% endfor %}
{% if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' -}}
{% endif %}
"""


class Videollama3ProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class Videollama3Processor(ProcessorMixin):
    r"""
    Modified from Qwen2VLProcessor
    Args:
        image_processor ([`Qwen2VLImageProcessor`], *optional*):
            The image processor is a required input.
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The tokenizer is a required input.
        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    image_processor_class = "Qwen2VLImageProcessor"
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        if chat_template is None:
            chat_template = DEFAULT_CHAT_TEMPLATE
        # super().__init__(image_processor, tokenizer, chat_template=chat_template)
        tokenizer.chat_template = chat_template
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.generation_prompt = self._infer_generation_prompt()
        self.generation_prompt_ids = self.tokenizer.encode(self.generation_prompt, return_tensors="pt")
        self.generation_prompt_length = len(self.generation_prompt_ids[0])
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
        self.eos_token_id = self.tokenizer.eos_token_id

    def get_generation_prompt(self):
        return self.generation_prompt

    def get_generation_prompt_ids(self):
        return self.generation_prompt_ids

    def _infer_generation_prompt(self):
        pseudo_message = [{"role": "user", "content": ""}]
        instruction = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=True)
        conversation = self.tokenizer.apply_chat_template(pseudo_message, tokenize=False, add_generation_prompt=False)
        return instruction.replace(conversation, "")

    def _process_text_with_label(
        self,
        text: List[Dict],
        image_grid_thw: torch.Tensor = None,
        image_downsampling: Optional[int] = None,
        **kwargs,
    ):
        assert kwargs.pop("return_tensors", "pt") == "pt", "Only PyTorch tensors are supported when return_labels=True."
        assert isinstance(text[0], dict), "When return_labels=True, text must be a list of messages."

        input_ids_list = []
        targets_list = []
        sample_types_list = []
        image_idx = 0

        for message_idx, message in enumerate(text):
            # 1. set chat template and append image tokens
            prompt = self.tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=False)
            prompt_chunks = prompt.split(DEFAULT_IMAGE_TOKEN)
            prompt = []
            for chunk_idx in range(len(prompt_chunks) - 1):
                prompt.append(prompt_chunks[chunk_idx])
                thw = image_grid_thw[image_idx]
                prompt.append(DEFAULT_IMAGE_TOKEN * (thw.prod() / image_downsampling**2).long())
                image_idx += 1
            prompt.append(prompt_chunks[-1])
            prompt = "".join(prompt)

            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")[0]
            input_ids_list.append(input_ids)

            targets = torch.full_like(input_ids, IGNORE_INDEX)
            sample_types = torch.full_like(input_ids, IGNORE_INDEX)
            if message["role"] == "assistant":
                targets[self.generation_prompt_length:-1] = input_ids[self.generation_prompt_length:-1].clone()
            elif message["role"] == "stream":
                diff = torch.diff((input_ids == self.image_token_id).float())
                image_end_indices = torch.nonzero(diff < 0)[:, 0]
                targets[image_end_indices + 1] = input_ids[image_end_indices + 1]
                sample_types = targets.clone()
                sample_types[torch.logical_and(sample_types > 0, sample_types != self.eos_token_id)] = 0
                targets[-2] = input_ids[-2]  # <|im_end|>

            # if message_idx > 0 and text[message_idx - 1]["role"] == "stream":
            #     targets[0] = input_ids[0]
            #     # TODO: consider non-special tokens
            #     sample_types[0] = input_ids[0]

            targets_list.append(targets)
            sample_types_list.append(sample_types)

        assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."

        targets = torch.cat(targets_list)
        sample_types = torch.cat(sample_types_list)
        types, counts = torch.unique(sample_types[sample_types > -1], return_counts=True)

        if len(types) > 0:
            target_num_samples = counts.amin()

            for type_id, type_count in zip(types, counts):
                if type_count > target_num_samples:
                    indices = torch.nonzero(sample_types == type_id)[:, 0]
                    random_selector = torch.randperm(indices.size(0))[:-target_num_samples]
                    targets[indices[random_selector]] = IGNORE_INDEX
                    sample_types[indices[random_selector]] = -1

        text_inputs = {
            "input_ids": torch.cat(input_ids_list),
            "labels": targets,
        }

        return text_inputs

    def _process_text_without_label(
        self,
        text: Union[List[str], List[Dict]],
        image_grid_thw: torch.Tensor = None,
        image_downsampling: Optional[int] = None,
        **kwargs,
    ):
        if isinstance(text[0], dict):
            warnings.warn("Input text is a list of messages. Automatically convert it to a string with 'apply_chat_template' with generation prompt.")
            text = [self.tokenizer.apply_chat_template(text, tokenize=False, add_generation_prompt=True)]

        image_idx = 0
        for i in range(len(text)):
            while DEFAULT_IMAGE_TOKEN in text[i]:
                thw = image_grid_thw[image_idx]
                text[i] = text[i].replace(DEFAULT_IMAGE_TOKEN, "<placeholder>" * (thw.prod() / image_downsampling**2).long(), 1)
                image_idx += 1
            text[i] = text[i].replace("<placeholder>", DEFAULT_IMAGE_TOKEN)
        assert len(image_grid_thw) == image_idx, "Number of images does not match the number of image tokens in the text."
+
text_inputs = self.tokenizer(text, **kwargs)
|
238 |
+
return text_inputs
|
239 |
+
|
240 |
+
def _process_text(
|
241 |
+
self,
|
242 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]],
|
243 |
+
image_grid_thw: torch.Tensor = None,
|
244 |
+
image_downsampling: Optional[int] = None,
|
245 |
+
return_labels: bool = False,
|
246 |
+
**kwargs,
|
247 |
+
):
|
248 |
+
if not isinstance(text, (list, tuple)):
|
249 |
+
text = [text]
|
250 |
+
assert len(text), "At least one text must be provided."
|
251 |
+
|
252 |
+
if return_labels:
|
253 |
+
return self._process_text_with_label(text, image_grid_thw, image_downsampling, **kwargs)
|
254 |
+
return self._process_text_without_label(text, image_grid_thw, image_downsampling, **kwargs)
|
255 |
+
|
256 |
+
def _process_image(
|
257 |
+
self,
|
258 |
+
images: ImageInput = None,
|
259 |
+
image_downsampling: Optional[int] = None,
|
260 |
+
**kwargs,
|
261 |
+
):
|
262 |
+
if image_downsampling is None:
|
263 |
+
image_downsampling = self.image_processor.merge_size
|
264 |
+
|
265 |
+
image_inputs = {
|
266 |
+
"images": [],
|
267 |
+
"grid_thws": [],
|
268 |
+
"image_downsampling": image_downsampling
|
269 |
+
}
|
270 |
+
if images is not None and len(images) > 0:
|
271 |
+
num_images = kwargs.get('num_images', len(images))
|
272 |
+
if 'num_images' in kwargs:
|
273 |
+
kwargs.pop('num_images')
|
274 |
+
for image in images:
|
275 |
+
outputs = self.image_processor(images=image, num_images=num_images, image_downsampling=image_downsampling, **kwargs)
|
276 |
+
# images shapes like: [tensor([patches, 1176]), ...]
|
277 |
+
# grid_thws shapes like: tensor([num_images, 3])
|
278 |
+
|
279 |
+
# flatten_patches1 = outputs["pixel_values"].reshape(26, 46, 3, -1)
|
280 |
+
# from matplotlib import pyplot as plt
|
281 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
282 |
+
# plt.savefig('9.png')
|
283 |
+
|
284 |
+
image_inputs["images"].append(outputs["pixel_values"]) #正常的
|
285 |
+
|
286 |
+
# flatten_patches1 = image_inputs["images"][0].reshape(26, 46, 3, -1)
|
287 |
+
# from matplotlib import pyplot as plt
|
288 |
+
# plt.imshow(flatten_patches1[:,:,:,0])
|
289 |
+
# plt.savefig('12.png')
|
290 |
+
image_inputs["grid_thws"].append(outputs["image_grid_thw"])
|
291 |
+
|
292 |
+
return image_inputs
|
293 |
+
|
294 |
+
|
295 |
+
|
296 |
+
def __call__(
|
297 |
+
self,
|
298 |
+
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput], List[Dict]] = None,
|
299 |
+
images: ImageInput = None,
|
300 |
+
image_downsampling: Optional[int] = None,
|
301 |
+
return_labels: bool = False,
|
302 |
+
**kwargs: Unpack[Videollama3ProcessorKwargs],
|
303 |
+
) -> BatchFeature:
|
304 |
+
"""
|
305 |
+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
|
306 |
+
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
307 |
+
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
|
308 |
+
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
309 |
+
|
310 |
+
Args:
|
311 |
+
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
312 |
+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
313 |
+
tensor. Both channels-first and channels-last formats are supported.
|
314 |
+
text (`str`, `List[str]`, `List[List[str]]`):
|
315 |
+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
316 |
+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
317 |
+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
318 |
+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
319 |
+
If set, will return tensors of a particular framework. Acceptable values are:
|
320 |
+
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
321 |
+
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
322 |
+
- `'np'`: Return NumPy `np.ndarray` objects.
|
323 |
+
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
327 |
+
|
328 |
+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
329 |
+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
330 |
+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
331 |
+
`None`).
|
332 |
+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
333 |
+
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
|
334 |
+
"""
|
335 |
+
output_kwargs = self._merge_kwargs(
|
336 |
+
Videollama3ProcessorKwargs,
|
337 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
338 |
+
**kwargs,
|
339 |
+
)
|
340 |
+
output_kwargs["text_kwargs"].pop("padding")
|
341 |
+
output_kwargs["text_kwargs"].pop("padding_side")
|
342 |
+
|
343 |
+
image_inputs = self._process_image(images, image_downsampling, **output_kwargs["images_kwargs"])
|
344 |
+
text_inputs = self._process_text(text, image_inputs["grid_thws"], image_downsampling, return_labels, **output_kwargs["text_kwargs"])
|
345 |
+
|
346 |
+
return BatchFeature(data={**text_inputs, **image_inputs})
|
347 |
+
|
348 |
+
def batch_decode(self, *args, **kwargs):
|
349 |
+
"""
|
350 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
351 |
+
refer to the docstring of this method for more information.
|
352 |
+
"""
|
353 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
354 |
+
|
355 |
+
def decode(self, *args, **kwargs):
|
356 |
+
"""
|
357 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
358 |
+
the docstring of this method for more information.
|
359 |
+
"""
|
360 |
+
return self.tokenizer.decode(*args, **kwargs)
|
361 |
+
|
362 |
+
@property
|
363 |
+
def model_input_names(self):
|
364 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
365 |
+
image_processor_input_names = self.image_processor.model_input_names
|
366 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
videollama3/model/projector.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Alibaba DAMO Academy
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
import os
|
17 |
+
import re
|
18 |
+
|
19 |
+
import einops
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from timm.models.layers import LayerNorm, LayerNorm2d
|
24 |
+
from timm.models.regnet import RegStage
|
25 |
+
from transformers import TRANSFORMERS_CACHE
|
26 |
+
|
27 |
+
|
28 |
+
def parse_snapshot_folder(repo_id, cache_dir=None, repo_type="model"):
|
29 |
+
revision = "main"
|
30 |
+
# 1. parse the downloaded cache folder
|
31 |
+
if cache_dir is None:
|
32 |
+
cache_dir = TRANSFORMERS_CACHE
|
33 |
+
else:
|
34 |
+
cache_dir = cache_dir
|
35 |
+
object_id = repo_id.replace("/", "--")
|
36 |
+
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
|
37 |
+
# 2. resolve refs (for instance to convert main to the associated commit sha)
|
38 |
+
refs_dir = os.path.join(repo_cache, "refs")
|
39 |
+
if os.path.isdir(refs_dir):
|
40 |
+
revision_file = os.path.join(refs_dir, revision)
|
41 |
+
if os.path.isfile(revision_file):
|
42 |
+
with open(revision_file) as f:
|
43 |
+
revision = f.read()
|
44 |
+
# 3. acquire the snapshot folder
|
45 |
+
folder = os.path.join(repo_cache, "snapshots", revision)
|
46 |
+
|
47 |
+
return folder
|
48 |
+
|
49 |
+
|
50 |
+
def load_mm_projector(model_path, cache_dir=None, token=None):
|
51 |
+
if os.path.exists(os.path.join(model_path, 'mm_projector.bin')):
|
52 |
+
is_local = True
|
53 |
+
folder = model_path
|
54 |
+
else:
|
55 |
+
is_local = False
|
56 |
+
folder = parse_snapshot_folder(model_path, cache_dir=cache_dir, repo_type="model")
|
57 |
+
if not os.path.exists(os.path.join(folder, 'mm_projector.bin')):
|
58 |
+
# downloading from remote repo
|
59 |
+
from huggingface_hub import snapshot_download
|
60 |
+
snapshot_download(repo_id=model_path, cache_dir=cache_dir, token=token)
|
61 |
+
|
62 |
+
mm_projector_weights = torch.load(os.path.join(folder, 'mm_projector.bin'), map_location='cpu')
|
63 |
+
mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
|
64 |
+
return mm_projector_weights
|
65 |
+
|
66 |
+
|
67 |
+
class IdentityMap(nn.Module):
|
68 |
+
|
69 |
+
def __init__(self):
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
def forward(self, x, *args, **kwargs):
|
73 |
+
return x
|
74 |
+
|
75 |
+
@property
|
76 |
+
def config(self):
|
77 |
+
return {"mm_projector_type": 'identity'}
|
78 |
+
|
79 |
+
|
80 |
+
def build_mlp(depth, hidden_size, output_hidden_size):
|
81 |
+
modules = [nn.Linear(hidden_size, output_hidden_size)]
|
82 |
+
for _ in range(1, depth):
|
83 |
+
modules.append(nn.GELU())
|
84 |
+
modules.append(nn.Linear(output_hidden_size, output_hidden_size))
|
85 |
+
return nn.Sequential(*modules)
|
86 |
+
|
87 |
+
|
88 |
+
class SimSpatialConv(nn.Module):
|
89 |
+
|
90 |
+
def __init__(self, config, downsample=(2, 2), padding=1, depth=1, mlp_depth=2):
|
91 |
+
super().__init__()
|
92 |
+
self.encoder_hidden_size = encoder_hidden_size = config.mm_hidden_size
|
93 |
+
self.output_hidden_size = output_hidden_size = config.hidden_size
|
94 |
+
self.downsample = downsample
|
95 |
+
self.padding = padding
|
96 |
+
self.sampler = nn.Sequential(
|
97 |
+
nn.Conv2d(
|
98 |
+
in_channels=self.encoder_hidden_size,
|
99 |
+
out_channels=4 * self.encoder_hidden_size,
|
100 |
+
kernel_size=self.downsample,
|
101 |
+
stride=self.downsample,
|
102 |
+
padding=self.padding,
|
103 |
+
bias=True
|
104 |
+
),
|
105 |
+
nn.SiLU(),
|
106 |
+
)
|
107 |
+
self.readout = build_mlp(mlp_depth, 4 * self.encoder_hidden_size, self.output_hidden_size)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
hw = int(x.size(1) ** 0.5)
|
111 |
+
x = einops.rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
|
112 |
+
x = self.sampler(x)
|
113 |
+
x = einops.rearrange(x, "b d h w -> b (h w) d")
|
114 |
+
x = self.readout(x)
|
115 |
+
return x
|
116 |
+
|
117 |
+
def cal_proj_size(self, input_size):
|
118 |
+
if isinstance(input_size, int):
|
119 |
+
input_size = (input_size, input_size)
|
120 |
+
height = math.ceil((input_size[0] + self.padding) / self.downsample[0])
|
121 |
+
width = math.ceil((input_size[1] + self.padding) / self.downsample[1])
|
122 |
+
return height * width
|
123 |
+
|
124 |
+
|
125 |
+
class MlpGeluProjector(nn.Module):
|
126 |
+
def __init__(self, config, projector_type):
|
127 |
+
super().__init__()
|
128 |
+
|
129 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
130 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
131 |
+
|
132 |
+
self.readout = build_mlp(mlp_depth, config.mm_hidden_size, config.hidden_size)
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
x = self.readout(x)
|
136 |
+
return x
|
137 |
+
|
138 |
+
def cal_proj_size(self, input_size):
|
139 |
+
if isinstance(input_size, int):
|
140 |
+
input_size = (input_size, input_size)
|
141 |
+
height = input_size[0]
|
142 |
+
width = input_size[1]
|
143 |
+
return height * width
|
144 |
+
|
145 |
+
|
146 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
147 |
+
# videollama3 projector only support image-wise operation now, i.e., prohibit the temporal aggregation
|
148 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
149 |
+
|
150 |
+
if projector_type == "linear":
|
151 |
+
# NOTE: for both linear and mlp2x_gelu projector type, mean pooling is adopted to aggreate video features
|
152 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
153 |
+
elif projector_type == "simp_spatial_conv":
|
154 |
+
return SimSpatialConv(config)
|
155 |
+
elif projector_type.startswith("mlp"):
|
156 |
+
return MlpGeluProjector(config, projector_type)
|
157 |
+
if projector_type == 'identity':
|
158 |
+
return IdentityMap()
|
159 |
+
|
160 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
videollama3/model/qwen2vl_encoder/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
|
2 |
+
from .image_processing import Qwen2VLImageProcessor
|
3 |
+
from .modeling_qwen2vl_encoder import Qwen2VisionTransformerPretrainedModel
|
videollama3/model/qwen2vl_encoder/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (432 Bytes). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/configuration_qwen2vl_encoder.cpython-310.pyc
ADDED
Binary file (1.92 kB). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/image_processing.cpython-310.pyc
ADDED
Binary file (16.9 kB). View file
|
|
videollama3/model/qwen2vl_encoder/__pycache__/modeling_qwen2vl_encoder.cpython-310.pyc
ADDED
Binary file (12.7 kB). View file
|
|
videollama3/model/qwen2vl_encoder/configuration_qwen2vl_encoder.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Qwen2VL model configuration"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import Union
|
19 |
+
|
20 |
+
from transformers.configuration_utils import PretrainedConfig
|
21 |
+
from transformers.utils import logging
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.get_logger(__name__)
|
25 |
+
|
26 |
+
|
27 |
+
class Qwen2VLVisionConfig(PretrainedConfig):
|
28 |
+
model_type = "qwen2_vl"
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
depth=32,
|
33 |
+
embed_dim=1280,
|
34 |
+
hidden_size=3584,
|
35 |
+
hidden_act="quick_gelu",
|
36 |
+
mlp_ratio=4,
|
37 |
+
num_heads=16,
|
38 |
+
in_channels=3,
|
39 |
+
patch_size=14,
|
40 |
+
spatial_merge_size=2,
|
41 |
+
temporal_patch_size=2,
|
42 |
+
**kwargs,
|
43 |
+
):
|
44 |
+
super().__init__(**kwargs)
|
45 |
+
|
46 |
+
self.depth = depth
|
47 |
+
self.embed_dim = embed_dim
|
48 |
+
self.hidden_size = hidden_size
|
49 |
+
self.hidden_act = hidden_act
|
50 |
+
self.mlp_ratio = mlp_ratio
|
51 |
+
self.num_heads = num_heads
|
52 |
+
self.in_channels = in_channels
|
53 |
+
self.patch_size = patch_size
|
54 |
+
self.spatial_merge_size = spatial_merge_size
|
55 |
+
self.temporal_patch_size = temporal_patch_size
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
59 |
+
cls._set_token_in_kwargs(kwargs)
|
60 |
+
|
61 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
62 |
+
|
63 |
+
# if config_dict.get("model_type") == "qwen2_vl":
|
64 |
+
# config_dict = config_dict["vision_config"]
|
65 |
+
|
66 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
67 |
+
logger.warning(
|
68 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
69 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
70 |
+
)
|
71 |
+
|
72 |
+
return cls.from_dict(config_dict, **kwargs)
|
videollama3/model/qwen2vl_encoder/image_processing.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""Image processor class for Qwen2-VL."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from typing import Dict, List, Optional, Union
|
24 |
+
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
28 |
+
from transformers.image_transforms import (
|
29 |
+
convert_to_rgb,
|
30 |
+
resize,
|
31 |
+
to_channel_dimension_format,
|
32 |
+
)
|
33 |
+
from transformers.image_utils import (
|
34 |
+
OPENAI_CLIP_MEAN,
|
35 |
+
OPENAI_CLIP_STD,
|
36 |
+
ChannelDimension,
|
37 |
+
ImageInput,
|
38 |
+
PILImageResampling,
|
39 |
+
VideoInput,
|
40 |
+
get_image_size,
|
41 |
+
infer_channel_dimension_format,
|
42 |
+
is_scaled_image,
|
43 |
+
is_valid_image,
|
44 |
+
make_list_of_images,
|
45 |
+
to_numpy_array,
|
46 |
+
valid_images,
|
47 |
+
validate_preprocess_arguments,
|
48 |
+
)
|
49 |
+
from transformers.utils import TensorType, is_vision_available, logging
|
50 |
+
|
51 |
+
|
52 |
+
logger = logging.get_logger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
if is_vision_available():
|
56 |
+
from PIL import Image
|
57 |
+
|
58 |
+
|
59 |
+
def make_batched_images(images) -> List[List[ImageInput]]:
|
60 |
+
"""
|
61 |
+
Accepts images in list or nested list format, and makes a list of images for preprocessing.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
|
65 |
+
The input image.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
list: A list of images.
|
69 |
+
"""
|
70 |
+
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
|
71 |
+
return [img for img_list in images for img in img_list]
|
72 |
+
|
73 |
+
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
|
74 |
+
return images
|
75 |
+
|
76 |
+
elif is_valid_image(images):
|
77 |
+
return [images]
|
78 |
+
|
79 |
+
raise ValueError(f"Could not make batched images from {images}")
|
80 |
+
|
81 |
+
|
82 |
+
# Copied from transformers.models.llava_next_video.image_processing_llava_next_video.make_batched_videos
|
83 |
+
def make_batched_videos(videos) -> List[VideoInput]:
|
84 |
+
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
85 |
+
return videos
|
86 |
+
|
87 |
+
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
88 |
+
if isinstance(videos[0], Image.Image):
|
89 |
+
return [videos]
|
90 |
+
elif len(videos[0].shape) == 4:
|
91 |
+
return [list(video) for video in videos]
|
92 |
+
|
93 |
+
elif is_valid_image(videos) and len(videos.shape) == 4:
|
94 |
+
return [list(videos)]
|
95 |
+
|
96 |
+
raise ValueError(f"Could not make batched video from {videos}")
|
97 |
+
|
98 |
+
|
99 |
+
def smart_resize(
|
100 |
+
height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280
|
101 |
+
):
|
102 |
+
"""Rescales the image so that the following conditions are met:
|
103 |
+
|
104 |
+
1. Both dimensions (height and width) are divisible by 'factor'.
|
105 |
+
|
106 |
+
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
107 |
+
|
108 |
+
3. The aspect ratio of the image is maintained as closely as possible.
|
109 |
+
|
110 |
+
"""
|
111 |
+
if height < factor or width < factor:
|
112 |
+
scale = factor / min(height, width)
|
113 |
+
width = round(scale * width)
|
114 |
+
height = round(scale * height)
|
115 |
+
elif max(height, width) / min(height, width) > 200:
|
116 |
+
raise ValueError(
|
117 |
+
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
118 |
+
)
|
119 |
+
h_bar = round(height / factor) * factor
|
120 |
+
w_bar = round(width / factor) * factor
|
121 |
+
if h_bar * w_bar > max_pixels:
|
122 |
+
beta = math.sqrt((height * width) / max_pixels)
|
123 |
+
h_bar = math.floor(height / beta / factor) * factor
|
124 |
+
w_bar = math.floor(width / beta / factor) * factor
|
125 |
+
elif h_bar * w_bar < min_pixels:
|
126 |
+
beta = math.sqrt(min_pixels / (height * width))
|
127 |
+
h_bar = math.ceil(height * beta / factor) * factor
|
128 |
+
w_bar = math.ceil(width * beta / factor) * factor
|
129 |
+
return h_bar, w_bar
|
130 |
+
|
131 |
+
|
132 |
+
class Qwen2VLImageProcessor(BaseImageProcessor):
|
133 |
+
r"""
|
134 |
+
Constructs a Qwen2-VL image processor that dynamically resizes images based on the original images.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
138 |
+
Whether to resize the image's (height, width) dimensions.
|
139 |
+
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
|
140 |
+
Resampling filter to use when resizing the image.
|
141 |
+
do_rescale (`bool`, *optional*, defaults to `True`):
|
142 |
+
Whether to rescale the image by the specified scale `rescale_factor`.
|
143 |
+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
144 |
+
Scale factor to use if rescaling the image.
|
145 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
146 |
+
Whether to normalize the image.
|
147 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
|
148 |
+
Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
149 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
|
150 |
+
Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
|
151 |
+
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
152 |
+
Whether to convert the image to RGB.
|
153 |
+
min_pixels (`int`, *optional*, defaults to `56 * 56`):
|
154 |
+
The min pixels of the image to resize the image.
|
155 |
+
max_pixels (`int`, *optional*, defaults to `28 * 28 * 1280`):
|
156 |
+
The max pixels of the image to resize the image.
|
157 |
+
patch_size (`int`, *optional*, defaults to 14):
|
158 |
+
The spacial patch size of the vision encoder.
|
159 |
+
temporal_patch_size (`int`, *optional*, defaults to 2):
|
160 |
+
The temporal patch size of the vision encoder.
|
161 |
+
merge_size (`int`, *optional*, defaults to 2):
|
162 |
+
The merge size of the vision encoder to llm encoder.
|
163 |
+
"""
|
164 |
+
|
165 |
+
model_input_names = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
|
166 |
+
|
167 |
+
def __init__(
|
168 |
+
self,
|
169 |
+
do_resize: bool = True,
|
170 |
+
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
171 |
+
do_rescale: bool = True,
|
172 |
+
rescale_factor: Union[int, float] = 1 / 255,
|
173 |
+
do_normalize: bool = True,
|
174 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
175 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
176 |
+
do_convert_rgb: bool = True,
|
177 |
+
min_pixels: int = 56 * 56,
|
178 |
+
max_pixels: int = 28 * 28 * 1280,
|
179 |
+
patch_size: int = 14,
|
180 |
+
temporal_patch_size: int = 2,
|
181 |
+
merge_size: int = 2,
|
182 |
+
**kwargs,
|
183 |
+
) -> None:
|
184 |
+
super().__init__(**kwargs)
|
185 |
+
self.do_resize = do_resize
|
186 |
+
self.resample = resample
|
187 |
+
self.do_rescale = do_rescale
|
188 |
+
self.rescale_factor = rescale_factor
|
189 |
+
self.do_normalize = do_normalize
|
190 |
+
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
191 |
+
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
192 |
+
self.min_pixels = min_pixels
|
193 |
+
self.max_pixels = max_pixels
|
194 |
+
self.patch_size = patch_size
|
195 |
+
self.temporal_patch_size = temporal_patch_size
|
196 |
+
self.merge_size = merge_size
|
197 |
+
self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
|
198 |
+
self.do_convert_rgb = do_convert_rgb
|
199 |
+
|
200 |
+
def _preprocess(
|
201 |
+
self,
|
202 |
+
images: Union[ImageInput, VideoInput],
|
203 |
+
do_resize: bool = None,
|
204 |
+
resample: PILImageResampling = None,
|
205 |
+
do_rescale: bool = None,
|
206 |
+
rescale_factor: float = None,
|
207 |
+
do_normalize: bool = None,
|
208 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
209 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
210 |
+
do_convert_rgb: bool = None,
|
211 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
212 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
213 |
+
num_images: Optional[int] = 1,
|
214 |
+
image_downsampling: Optional[int] = None,
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
images (`ImageInput`):
|
221 |
+
Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
|
222 |
+
vision_info (`List[Dict]`, *optional*):
|
223 |
+
Optional list of dictionaries containing additional information about vision inputs.
|
224 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
225 |
+
Whether to resize the image.
|
226 |
+
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
227 |
+
Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
|
228 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
229 |
+
Whether to rescale the image.
|
230 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
231 |
+
Scale factor to use if rescaling the image.
|
232 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
233 |
+
Whether to normalize the image.
|
234 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
235 |
+
Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
236 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
237 |
+
Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
|
238 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
239 |
+
Whether to convert the image to RGB.
|
240 |
+
data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
|
241 |
+
The channel dimension format for the output image. Can be one of:
|
242 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
243 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
244 |
+
- Unset: Use the channel dimension format of the input image.
|
245 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
246 |
+
The channel dimension format for the input image. Can be one of:
|
247 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
248 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
249 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
250 |
+
"""
|
251 |
+
images = make_list_of_images(images)
|
252 |
+
|
253 |
+
if do_convert_rgb:
|
254 |
+
images = [convert_to_rgb(image) for image in images]
|
255 |
+
|
256 |
+
# All transformations expect numpy arrays.
|
257 |
+
images = [to_numpy_array(image) for image in images]
|
258 |
+
|
259 |
+
if is_scaled_image(images[0]) and do_rescale:
|
260 |
+
logger.warning_once(
|
261 |
+
"It looks like you are trying to rescale already rescaled images. If the input"
|
262 |
+
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
263 |
+
)
|
264 |
+
if input_data_format is None:
|
265 |
+
# We assume that all images have the same channel dimension format.
|
266 |
+
input_data_format = infer_channel_dimension_format(images[0])
|
267 |
+
|
268 |
+
height, width = get_image_size(images[0], channel_dim=input_data_format)
|
269 |
+
resized_height, resized_width = height, width
|
270 |
+
processed_images = []
|
271 |
+
for image in images:
|
272 |
+
if do_resize:
|
273 |
+
max_pixels = int(self.max_pixels / (self.merge_size / image_downsampling)**2)
|
274 |
+
resized_height, resized_width = smart_resize(
|
275 |
+
height,
|
276 |
+
width,
|
277 |
+
factor=self.patch_size * image_downsampling,
|
278 |
+
min_pixels=self.min_pixels,
|
279 |
+
max_pixels=int(max_pixels // num_images),
|
280 |
+
)
|
281 |
+
image = resize(
|
282 |
+
image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
|
283 |
+
)
|
284 |
+
|
285 |
+
if do_rescale:
|
286 |
+
image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
|
287 |
+
|
288 |
+
if do_normalize:
|
289 |
+
image = self.normalize(
|
290 |
+
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
|
291 |
+
)
|
292 |
+
|
293 |
+
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
294 |
+
processed_images.append(image)
|
295 |
+
|
296 |
+
patches = np.array(processed_images)
|
297 |
+
if data_format == ChannelDimension.LAST:
|
298 |
+
patches = patches.transpose(0, 3, 1, 2)
|
299 |
+
if patches.shape[0] == 1:
|
300 |
+
patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
|
301 |
+
channel = patches.shape[1]
|
302 |
+
grid_t = patches.shape[0] // self.temporal_patch_size
|
303 |
+
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
|
304 |
+
patches = patches.reshape(
|
305 |
+
grid_t,
|
306 |
+
self.temporal_patch_size,
|
307 |
+
channel,
|
308 |
+
grid_h // image_downsampling,
|
309 |
+
image_downsampling,
|
310 |
+
self.patch_size,
|
311 |
+
grid_w // image_downsampling,
|
312 |
+
image_downsampling,
|
313 |
+
self.patch_size,
|
314 |
+
)
|
315 |
+
patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
316 |
+
flatten_patches = patches.reshape(
|
317 |
+
grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
|
318 |
+
)
|
319 |
+
return flatten_patches, (grid_t, grid_h, grid_w)
|
320 |
+
|
321 |
+
def preprocess(
|
322 |
+
self,
|
323 |
+
images: ImageInput,
|
324 |
+
videos: VideoInput = None,
|
325 |
+
do_resize: bool = None,
|
326 |
+
size: Dict[str, int] = None,
|
327 |
+
resample: PILImageResampling = None,
|
328 |
+
do_rescale: bool = None,
|
329 |
+
rescale_factor: float = None,
|
330 |
+
do_normalize: bool = None,
|
331 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
332 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
333 |
+
do_convert_rgb: bool = None,
|
334 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
335 |
+
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
336 |
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
337 |
+
num_images: Optional[int] = 1,
|
338 |
+
image_downsampling: Optional[int] = None,
|
339 |
+
):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
images (`ImageInput`):
|
343 |
+
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
344 |
+
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
345 |
+
videos (`VideoInput`):
|
346 |
+
Video to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
|
347 |
+
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
|
348 |
+
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
349 |
+
Whether to resize the image.
|
350 |
+
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
351 |
+
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
|
352 |
+
the longest edge resized to keep the input aspect ratio.
|
353 |
+
resample (`int`, *optional*, defaults to `self.resample`):
|
354 |
+
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
|
355 |
+
has an effect if `do_resize` is set to `True`.
|
356 |
+
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
357 |
+
Whether to rescale the image.
|
358 |
+
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
359 |
+
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
360 |
+
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
361 |
+
Whether to normalize the image.
|
362 |
+
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
363 |
+
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
|
364 |
+
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
365 |
+
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
|
366 |
+
`True`.
|
367 |
+
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
368 |
+
Whether to convert the image to RGB.
|
369 |
+
return_tensors (`str` or `TensorType`, *optional*):
|
370 |
+
The type of tensors to return. Can be one of:
|
371 |
+
- Unset: Return a list of `np.ndarray`.
|
372 |
+
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
373 |
+
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
374 |
+
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
375 |
+
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
376 |
+
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
377 |
+
The channel dimension format for the output image. Can be one of:
|
378 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
379 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
380 |
+
- Unset: Use the channel dimension format of the input image.
|
381 |
+
input_data_format (`ChannelDimension` or `str`, *optional*):
|
382 |
+
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
383 |
+
from the input image. Can be one of:
|
384 |
+
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
385 |
+
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
386 |
+
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
387 |
+
|
388 |
+
"""
|
389 |
+
do_resize = do_resize if do_resize is not None else self.do_resize
|
390 |
+
size = size if size is not None else self.size
|
391 |
+
resample = resample if resample is not None else self.resample
|
392 |
+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
393 |
+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
394 |
+
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
395 |
+
image_mean = image_mean if image_mean is not None else self.image_mean
|
396 |
+
image_std = image_std if image_std is not None else self.image_std
|
397 |
+
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
398 |
+
image_downsampling = image_downsampling if image_downsampling is not None else self.merge_size
|
399 |
+
|
400 |
+
if images is not None:
|
401 |
+
images = make_batched_images(images)
|
402 |
+
if videos is not None:
|
403 |
+
videos = make_batched_videos(videos)
|
404 |
+
|
405 |
+
if images is not None and not valid_images(images):
|
406 |
+
raise ValueError(
|
407 |
+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
408 |
+
"torch.Tensor, tf.Tensor or jax.ndarray."
|
409 |
+
)
|
410 |
+
|
411 |
+
validate_preprocess_arguments(
|
412 |
+
rescale_factor=rescale_factor,
|
413 |
+
do_normalize=do_normalize,
|
414 |
+
image_mean=image_mean,
|
415 |
+
image_std=image_std,
|
416 |
+
do_resize=do_resize,
|
417 |
+
size=size,
|
418 |
+
resample=resample,
|
419 |
+
)
|
420 |
+
|
421 |
+
if images is not None:
|
422 |
+
pixel_values, vision_grid_thws = [], []
|
423 |
+
for image in images:
|
424 |
+
patches, image_grid_thw = self._preprocess(
|
425 |
+
image,
|
426 |
+
do_resize=do_resize,
|
427 |
+
resample=resample,
|
428 |
+
do_rescale=do_rescale,
|
429 |
+
rescale_factor=rescale_factor,
|
430 |
+
do_normalize=do_normalize,
|
431 |
+
image_mean=image_mean,
|
432 |
+
image_std=image_std,
|
433 |
+
data_format=data_format,
|
434 |
+
do_convert_rgb=do_convert_rgb,
|
435 |
+
input_data_format=input_data_format,
|
436 |
+
num_images=num_images,
|
437 |
+
image_downsampling=image_downsampling,
|
438 |
+
)
|
439 |
+
pixel_values.extend(patches)
|
440 |
+
vision_grid_thws.append(image_grid_thw)
|
441 |
+
pixel_values = np.array(pixel_values)
|
442 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
443 |
+
data = {"pixel_values": pixel_values, "image_grid_thw": vision_grid_thws}
|
444 |
+
|
445 |
+
if videos is not None:
|
446 |
+
pixel_values, vision_grid_thws = [], []
|
447 |
+
for images in videos:
|
448 |
+
patches, video_grid_thw = self._preprocess(
|
449 |
+
images,
|
450 |
+
do_resize=do_resize,
|
451 |
+
resample=resample,
|
452 |
+
do_rescale=do_rescale,
|
453 |
+
rescale_factor=rescale_factor,
|
454 |
+
do_normalize=do_normalize,
|
455 |
+
image_mean=image_mean,
|
456 |
+
image_std=image_std,
|
457 |
+
data_format=data_format,
|
458 |
+
do_convert_rgb=do_convert_rgb,
|
459 |
+
input_data_format=input_data_format,
|
460 |
+
num_images=num_images,
|
461 |
+
image_downsampling=image_downsampling,
|
462 |
+
)
|
463 |
+
pixel_values.extend(patches)
|
464 |
+
vision_grid_thws.append(video_grid_thw)
|
465 |
+
pixel_values = np.array(pixel_values)
|
466 |
+
vision_grid_thws = np.array(vision_grid_thws)
|
467 |
+
data = {"pixel_values_videos": pixel_values, "video_grid_thw": vision_grid_thws}
|
468 |
+
|
469 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
videollama3/model/qwen2vl_encoder/modeling_qwen2vl_encoder.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""PyTorch Qwen2-VL model."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
from dataclasses import dataclass
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
import torch.utils.checkpoint
|
30 |
+
from torch.nn import CrossEntropyLoss, LayerNorm
|
31 |
+
from transformers.activations import ACT2FN
|
32 |
+
from transformers.cache_utils import Cache, StaticCache
|
33 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
34 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
from transformers.utils import (add_start_docstrings,
|
37 |
+
add_start_docstrings_to_model_forward,
|
38 |
+
is_flash_attn_2_available,
|
39 |
+
is_flash_attn_greater_or_equal_2_10, logging,
|
40 |
+
replace_return_docstrings)
|
41 |
+
|
42 |
+
from .configuration_qwen2vl_encoder import Qwen2VLVisionConfig
|
43 |
+
|
44 |
+
if is_flash_attn_2_available():
|
45 |
+
from flash_attn import flash_attn_varlen_func
|
46 |
+
from transformers.modeling_flash_attention_utils import \
|
47 |
+
_flash_attention_forward
|
48 |
+
else:
|
49 |
+
flash_attn_varlen_func = None
|
50 |
+
|
51 |
+
logger = logging.get_logger(__name__)
|
52 |
+
|
53 |
+
|
54 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
55 |
+
def rotate_half(x):
|
56 |
+
"""Rotates half the hidden dims of the input."""
|
57 |
+
x1 = x[..., : x.shape[-1] // 2]
|
58 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
59 |
+
return torch.cat((-x2, x1), dim=-1)
|
60 |
+
|
61 |
+
|
62 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
63 |
+
orig_dtype = tensor.dtype
|
64 |
+
tensor = tensor.float()
|
65 |
+
cos = freqs.cos()
|
66 |
+
sin = freqs.sin()
|
67 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
68 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
69 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
70 |
+
output = output.to(orig_dtype)
|
71 |
+
return output
|
72 |
+
|
73 |
+
|
74 |
+
class VisionRotaryEmbedding(nn.Module):
|
75 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
76 |
+
super().__init__()
|
77 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
78 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
79 |
+
|
80 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
81 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
82 |
+
freqs = torch.outer(seq, self.inv_freq)
|
83 |
+
return freqs
|
84 |
+
|
85 |
+
|
86 |
+
class PatchEmbed(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
patch_size: int = 14,
|
90 |
+
temporal_patch_size: int = 2,
|
91 |
+
in_channels: int = 3,
|
92 |
+
embed_dim: int = 1152,
|
93 |
+
) -> None:
|
94 |
+
super().__init__()
|
95 |
+
self.patch_size = patch_size
|
96 |
+
self.temporal_patch_size = temporal_patch_size
|
97 |
+
self.in_channels = in_channels
|
98 |
+
self.embed_dim = embed_dim
|
99 |
+
|
100 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
101 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
102 |
+
|
103 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
104 |
+
target_dtype = self.proj.weight.dtype
|
105 |
+
hidden_states = hidden_states.view(
|
106 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
107 |
+
)
|
108 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
109 |
+
return hidden_states
|
110 |
+
|
111 |
+
|
112 |
+
class PatchMerger(nn.Module):
|
113 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
114 |
+
super().__init__()
|
115 |
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
116 |
+
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x


class VisionMlp(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class VisionAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


class VisionFlashAttention2(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output


class VisionSdpaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output


QWEN2_VL_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
    "sdpa": VisionSdpaAttention,
}


class Qwen2VLVisionBlock(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

        self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
            config.embed_dim, num_heads=config.num_heads
        )
        self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


class Qwen2VLPreTrainedModel(PreTrainedModel):
    config_class = Qwen2VLVisionConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
    config_class = Qwen2VLVisionConfig
    _no_split_modules = ["Qwen2VLVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size
        self.gradient_checkpointing = False

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
        )
        # if self.spatial_merge_size > 1:
        #     self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)

    def get_dtype(self) -> torch.dtype:
        return self.blocks[0].mlp.fc2.weight.dtype

    def get_device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw, strides):
        pos_ids = []
        for (t, h, w), stride in zip(grid_thw, strides):
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // stride,
                stride,
                w // stride,
                stride,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(self, hidden_states, grid_thws, strides) -> torch.Tensor:
        hidden_states = self.patch_embed(hidden_states)

        # BUG: These codes will cause deepspeed issue: `RuntimeError: disagreement between rank0 and rankx`
        # rotary_pos_emb = []
        # for thw in grid_thws:
        #     rotary_pos_emb.append(self.rot_pos_emb(thw).unsqueeze(0))
        # rotary_pos_emb1 = torch.cat(rotary_pos_emb, dim=1).squeeze(0)
        # grid_thws = torch.cat(grid_thws, dim=0)

        # new version of creating rotary position embedding
        # grid_thws shapes like [batch_flatten_image_num, 3]
        # grid_thws = torch.cat(grid_thws, dim=0) # is conducted in the `encoder.py`
        rotary_pos_emb = self.rot_pos_emb(grid_thws, strides)

        cu_seqlens = torch.repeat_interleave(grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]).cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # if self.spatial_merge_size > 1:
        #     hidden_states = self.merger(hidden_states)
        return hidden_states
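A minimal, self-contained sketch (not part of the repository files) of how the `cu_seqlens` boundaries used by the attention classes above translate into a block-diagonal mask over packed patch sequences. The helper name `block_diagonal_mask` and the toy lengths are illustrative assumptions, not APIs defined by this code base.

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # cu_seqlens holds cumulative sequence lengths, e.g. [0, 16, 48] for two
    # packed images of 16 and 32 patches; tokens may only attend within their
    # own image, so everything off the block diagonal stays at -inf.
    total = int(cu_seqlens[-1])
    mask = torch.full((1, total, total), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[..., s:e, s:e] = 0
    return mask

# Example: two images packed into one 48-token sequence.
mask = block_diagonal_mask(torch.tensor([0, 16, 48]), torch.float32)
print(mask.shape)  # torch.Size([1, 48, 48])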
videollama3/model/region_encoder.py
ADDED
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

class MaskExtractor(nn.Module):
    def __init__(self, config, mm_hidden_size, depth=2):
        super(MaskExtractor, self).__init__()
        self.mask_pooling = MaskPooling()
        modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        self.feat_linear = nn.Sequential(*modules)

    def forward(self, feats, masks):
        query_feats = []

        if masks is None:  # inference without region prompts
            return None
            # masks = torch.zeros((1, 1, 336, 336)).to(feats.device).float()

        num_imgs = len(masks)
        region_token_nums = []
        image_idx = 0
        for idx in range(num_imgs):
            if masks[idx] is None:
                continue
            for mask_idx in range(len(masks[idx])):
                mask = masks[idx][mask_idx].unsqueeze(0).unsqueeze(0).float()
                if len(mask[0]) == 0:
                    print('mask error')
                    mask = torch.zeros((1, 1, 336, 336)).to(feats.device).float()

                feat = feats[image_idx].unsqueeze(0)
                image_idx += 1

                # h, w = feat.shape[1:3]
                feat = feat.permute(0, 3, 1, 2)

                raw_dtype = feat.dtype
                feat = feat.to(mask.dtype)

                mask_feat_raw = self.mask_pooling(feat, mask)  # [n, 1024]

                query_feats.append(mask_feat_raw)
        if len(query_feats) == 0:
            return None
        mask_feats = torch.cat(query_feats, dim=0)
        mask_feats = mask_feats.to(feats[0].dtype)
        mask_feats_linear = self.feat_linear(mask_feats)
        return mask_feats_linear

def kmeans_fast(tokens, num_clusters=10, num_iterations=20):
    # tokens: input token features of shape [n, d]
    # num_clusters: number of groups to compress the tokens into
    # num_iterations: number of K-means iterations

    # initialize the centroids
    n, d = tokens.shape
    centroids = tokens[torch.randperm(n)[:num_clusters]]

    for _ in range(num_iterations):
        # expand tokens and centroids to compute all pairwise distances without an explicit loop
        tokens_expand = tokens.unsqueeze(1)        # [n, 1, d]
        centroids_expand = centroids.unsqueeze(0)  # [1, num_clusters, d]

        # squared distance from every token to every centroid
        distances = torch.sum((tokens_expand - centroids_expand) ** 2, dim=2)  # [n, num_clusters]

        # assign each token to its nearest centroid
        labels = torch.argmin(distances, dim=1)  # [n]

        # recompute the centroids
        new_centroids = torch.stack([tokens[labels == i].mean(dim=0) if tokens[labels == i].size(0) > 0 else centroids[i] for i in range(num_clusters)])

        # check for convergence
        if torch.allclose(centroids, new_centroids, atol=1e-6):
            break

        centroids = new_centroids

    return centroids

class MaskPooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, mask):

        if not x.shape[-2:] == mask.shape[-2:]:
            # resize the features to the mask resolution
            x = F.interpolate(x, size=mask.shape[-2:], mode='bilinear', align_corners=False)
            # mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        if not x.device == mask.device:
            mask = mask.to(x.device)
        # b, c, h, w = x.shape
        # b, q, h, w = mask.shape
        mask = (mask > 0).to(mask.dtype)
        mask = mask.permute(1, 0, 2, 3)
        denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8

        mask_emb = x * mask
        mask = torch.any(mask_emb != 0, dim=(0, 1))
        mask_emb = mask_emb[:, :, mask]
        mask_embedding = mask_emb[0].permute(1, 0)

        if len(mask_embedding) > 10:  # FIXME
            mask_embedding = kmeans_fast(mask_embedding)

        return mask_embedding


def build_region_encoder(config, mm_hidden_size):

    return MaskExtractor(config, mm_hidden_size)
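A standalone sketch of the idea behind MaskPooling + kmeans_fast above: keep only the feature vectors that fall under a binary region mask, then cap them at a fixed number of cluster centroids. The function name `pool_region_tokens` and the toy shapes are hypothetical, introduced only for this example.

import torch

def pool_region_tokens(feat: torch.Tensor, mask: torch.Tensor, max_tokens: int = 10) -> torch.Tensor:
    # feat: [C, H, W] visual features; mask: [H, W] binary region mask.
    tokens = feat[:, mask > 0].T             # [num_masked_positions, C]
    if len(tokens) <= max_tokens:
        return tokens
    # crude k-means compression, mirroring kmeans_fast above
    centroids = tokens[torch.randperm(len(tokens))[:max_tokens]]
    for _ in range(20):
        dists = torch.cdist(tokens, centroids)   # [n, k]
        labels = dists.argmin(dim=1)
        centroids = torch.stack([
            tokens[labels == i].mean(dim=0) if (labels == i).any() else centroids[i]
            for i in range(max_tokens)
        ])
    return centroids                             # [max_tokens, C]

feat = torch.randn(1024, 24, 24)
mask = torch.zeros(24, 24)
mask[4:12, 6:18] = 1
print(pool_region_tokens(feat, mask).shape)  # torch.Size([10, 1024])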
videollama3/model/videollama3_arch.py
ADDED
@@ -0,0 +1,422 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
from abc import ABC, abstractmethod

import einops
import torch
import torch.distributed as dist
import torch.nn as nn
import numpy as np

from ..constants import IGNORE_INDEX, MODAL_INDEX_MAP, NUM_FRAMES
from .encoder import build_vision_encoder
from .projector import build_vision_projector, load_mm_projector
from .region_encoder import build_region_encoder
from ..mm_utils import reshape_images_to_raw_grid

def spatial_downsampling(features, grid_thws, strides):
    n, c = features.shape

    flatten_grid_thws = torch.cat([grid_thw for batch_grid_thws in grid_thws for grid_thw in batch_grid_thws])
    split_sizes = [grid_thw.prod() for grid_thw in flatten_grid_thws]
    features = torch.split(features, split_sizes)
    flatten_strides = [stride for batch_strides in strides for stride in batch_strides]

    new_features = []
    for feature, grid_thw, stride in zip(features, flatten_grid_thws, flatten_strides):
        # NOTE: adapted for reshape in image processor
        feature = feature.view(grid_thw[0], grid_thw[1] // stride, grid_thw[2] // stride, stride, stride, c).permute(0, 1, 3, 2, 4, 5)
        feature = feature.reshape(grid_thw[0], grid_thw[1], grid_thw[2], c).permute(0, 3, 1, 2)
        # NOTE: previous version model is align_corners=True
        new_feature = torch.nn.functional.interpolate(feature, (math.ceil(grid_thw[1] / stride), math.ceil(grid_thw[2] / stride)), mode='bilinear')
        # new_feature = nn.functional.avg_pool2d(feature, stride)
        # new_feature = nn.functional.max_pool2d(feature, stride)
        new_features.append(new_feature.permute(0, 2, 3, 1).view(-1, c))
    new_features = torch.cat(new_features)

    return new_features


class Videollama3MetaModel:

    def __init__(self, config):
        super(Videollama3MetaModel, self).__init__(config)

        if hasattr(config, "vision_encoder") or hasattr(config, "mm_vision_encoder"):
            self.vision_encoder = build_vision_encoder(config, delay_load=False)
            self.mm_projector = build_vision_projector(config)
            self.region_encoder = build_region_encoder(config, self.vision_encoder.hidden_size)

    def get_vision_encoder(self):
        vision_encoder = getattr(self, 'vision_encoder', None)
        if type(vision_encoder) is list:
            vision_encoder = vision_encoder[0]
        return vision_encoder

    def get_mm_projector(self):
        return self.mm_projector

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_encoder = model_args.vision_encoder
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_projector = model_args.pretrain_mm_projector

        self.config.mm_vision_encoder = vision_encoder

        if self.get_vision_encoder() is None:
            vision_encoder = build_vision_encoder(model_args)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_encoder = [vision_encoder]
            else:
                self.vision_encoder = vision_encoder
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_encoder = self.vision_encoder[0]
            else:
                vision_encoder = self.vision_encoder
            # NOTE: only compatible with delay_load encoder
            # vision_encoder.load_model(vision_encoder.cfg_only)

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_encoder.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_projector is not None:
            if os.path.exists(pretrain_mm_projector):
                is_local = True
                if os.path.isdir(pretrain_mm_projector):
                    mm_projector_weights = load_mm_projector(pretrain_mm_projector)
                else:
                    mm_projector_weights = torch.load(pretrain_mm_projector, map_location='cpu')
            else:
                # Support loading projector weights from remote HuggingFace model hub
                is_local = False
                pretrain_mm_projector = pretrain_mm_projector.replace('mm_projector.bin', '')
                pretrain_mm_projector = pretrain_mm_projector.strip('/').strip('\\').strip()
                mm_projector_weights = load_mm_projector(pretrain_mm_projector)

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            # set strict=False to avoid missing key error regarding bert.embeddings.position_ids
            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)


class Videollama3MetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def num_frames(self):
        if hasattr(self.config, 'num_frames'):
            return self.config.num_frames
        else:
            return NUM_FRAMES

    def spatial_merge_size(self):
        if hasattr(self.config, 'spatial_merge_size'):
            return self.config.spatial_merge_size
        else:
            return 1

    def get_vision_encoder(self):
        return self.get_model().get_vision_encoder()

    def get_mm_projector(self):
        return self.get_model().get_mm_projector()

    def encode_images(self, images, grid_thws, strides):
        """
        images shape [b c h w]
        """
        images_features = self.get_model().get_vision_encoder()(images, grid_thws=grid_thws, strides=strides)
        # images_features = spatial_downsampling(images_features, grid_thws, stride=self.config.spatial_merge_size)
        mm_features = spatial_downsampling(images_features, grid_thws, strides=strides)
        images_features = self.get_model().mm_projector(mm_features)

        return images_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, position_ids=None, masks=None, additional_images=None,
    ):
        if self.config.use_token_compression:
            return self.prepare_inputs_labels_for_multimodal_with_compression(input_ids, attention_mask, past_key_values, labels, images, position_ids, masks, additional_images)

        # # images shape (modal, tensor, flag)
        # vision_encoder = self.get_vision_encoder()
        # # NOTE: text-only situation
        # if vision_encoder is None or images is None or input_ids.shape[1] == 1:
        #     return input_ids, attention_mask, past_key_values, None, labels, position_ids

        # # NOTE: Equivalent to the following code:
        # # images_tensor = [image for modal, image, image_flag, grid_thw in images]
        # # images_flag = [image_flag for modal, image, image_flag, grid_thw in images]
        # # grid_thws = [grid_thw for modal, image, image_flag, grid_thw in images]
        # modals, images, grid_thws = zip(*images)

        # images_flag = []
        # strides = []
        # for modal, grid_thw in zip(modals, grid_thws):
        #     grid_thw = torch.cat(grid_thw)
        #     stride = self.config.spatial_merge_size if modal == "video" else 1
        #     num_patches = grid_thw.prod(dim=-1).sum().div(stride**2).long()
        #     image_flag = torch.full((num_patches, ), 0 if modal == 'text' else 1)
        #     images_flag.append(image_flag)
        #     strides.append([stride] * grid_thw.size(0))
        # images_flag_tensor = torch.cat(images_flag)

        # mm_features = self.encode_images(images, grid_thws, strides)
        # mm_features = mm_features[images_flag_tensor.to(mm_features.device) == 1].to(input_ids.device)

        # additional_images_list = []
        # additional_images_thw = []
        # additional_images_strides = []

        # for i in range(len(additional_images)):
        #     additional_images_list.append(torch.from_numpy(np.array(additional_images[0][0])).to(mm_features.dtype).to(mm_features.device))
        #     additional_images_thw.append(torch.tensor(additional_images[0][1][0]).to(mm_features.device))
        #     additional_images_strides.append([1]*len(additional_images[0][1][0]))

        # image_selected = (input_ids == self.config.image_token_index)
        # audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])
        # input_ids[image_selected] = 0
        # input_ids[audio_selected] = 0

        # input_embeds = self.get_model().embed_tokens(input_ids).clone()

        # B, N, C = input_embeds.shape
        # input_embeds = input_embeds.reshape(B * N, C).to(input_ids.device)
        # image_selected = image_selected.reshape(B * N)
        # audio_selected = audio_selected.reshape(B * N)

        # input_embeds[image_selected] = input_embeds[image_selected] * 0.0 + mm_features.reshape(-1, C)

        # # replace region token
        # mask_selected = (input_ids == self.config.region_token_index)
        # if mask_selected.sum() > 0:
        #     additional_images_features = self.get_model().get_vision_encoder()(additional_images_list, grid_thws=[additional_images_thw], strides=additional_images_strides)
        #     reshaped_features = reshape_images_to_raw_grid(additional_images_features, additional_images_thw)
        #     mask_additional_image_features = []
        #     for idx in mask_ids:
        #         mask_additional_image_features.append(reshaped_features[idx])
        #     mask_feats = self.model.region_encoder(mask_additional_image_features, masks)
        #     input_embeds[mask_selected] = input_embeds[mask_selected]*0.0 + mask_feats

        # input_embeds = input_embeds.reshape(B, N, C)

        # return None, attention_mask, past_key_values, input_embeds, labels, position_ids

    def prepare_inputs_labels_for_multimodal_with_compression(
        self, input_ids, attention_mask, past_key_values, labels, images, position_ids=None, masks=None, additional_images=None,
    ):
        # images shape (modal, tensor, flag)
        vision_encoder = self.get_vision_encoder()
        # NOTE: text-only situation
        if vision_encoder is None or images is None or input_ids.shape[1] == 1:
            return input_ids, attention_mask, past_key_values, None, labels, position_ids

        # NOTE: Equivalent to the following code:
        # images_tensor = [image for modal, image, image_flag, grid_thw in images]
        # images_flag = [image_flag for modal, image, image_flag, grid_thw in images]
        # grid_thws = [grid_thw for modal, image, image_flag, grid_thw in images]
        modals, images, grid_thws = zip(*images)

        images_flag = []
        visual_masks = []
        strides = []
        visual_trunc_masks = []

        for modal, image, grid_thw in zip(modals, images, grid_thws):
            grid_thw = torch.cat(grid_thw)
            stride = self.config.spatial_merge_size if modal == "video" else 1
            num_patches = grid_thw.prod(dim=-1).sum().div(stride**2).long()
            image_flag = torch.full((num_patches, ), 0 if modal == 'text' else 1)
            images_flag.append(image_flag)
            strides.append([stride] * grid_thw.size(0))

            if modal == "image" or (modal == "video" and len(image) == 1):
                visual_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))
                visual_trunc_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))

            elif modal == "video":
                # NOTE: video frame compressor
                n, h, w = len(image), grid_thw[0][1], grid_thw[0][2]
                image = torch.stack(image, dim=0).view(n, (h // stride) * (w // stride), -1)

                threshold = 0.1
                min_tokens = 1
                pixel_diff = image[1:] - image[:-1]
                pixel_diff = torch.abs(pixel_diff).mean(dim=-1) * 255
                pixel_diff = torch.cat([torch.full_like(pixel_diff[0:1], threshold + 1), pixel_diff], dim=0)
                # if dist.get_rank() == 0:
                #     print(pixel_diff.shape, image.shape)
                mask = pixel_diff > threshold
                padding_ids = torch.nonzero(mask.sum(dim=1) < min_tokens)[:, 0]
                # mask[padding_ids, torch.randperm(min_tokens)] = 1
                mask[padding_ids, :min_tokens] = 1
                visual_masks.append(mask.flatten())
                visual_trunc_masks.append(torch.ones((num_patches,), dtype=torch.bool, device=input_ids.device))

            elif modal == "text":
                visual_trunc_masks.append(torch.ones((0,), dtype=torch.bool, device=input_ids.device))

        images_flag_tensor = torch.cat(images_flag)
        mm_features = self.encode_images(images, grid_thws, strides)
        mm_features = mm_features[images_flag_tensor.to(mm_features.device) == 1]

        additional_images_list = []
        additional_images_thw = []
        additional_images_strides = []

        if additional_images is not None:  # and additional_images[0] is not None
            for i in range(len(additional_images)):
                for img_idx in range(len(additional_images[i][0])):
                    additional_images_list.append([torch.from_numpy(np.array(additional_images[i][0][img_idx])).to(mm_features.dtype).to(mm_features.device)])
                    additional_images_thw.append([torch.tensor(np.array(additional_images[i][1][img_idx])).to(mm_features.device)])
                    additional_images_strides.append([1]*len(additional_images[i][1][img_idx]))
                # additional_images_list.append(additional_images[i][0])
                # additional_images_thw.append(additional_images[i][1])
                # additional_images_strides.append([1]*len(additional_images[i][1]))

        # import pdb
        # pdb.set_trace()

        B, N = input_ids.shape
        C = mm_features.shape[-1]

        assert B == 1, "Only support batch flattening for now"
        input_ids = input_ids.view(B * N)
        image_selected = (input_ids == self.config.image_token_index)
        audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])

        if len(visual_masks) > 0:
            # if dist.get_rank() == 0:
            #     print(grid_thws, [x.shape for x in visual_masks])
            visual_masks = torch.cat(visual_masks)
            # print((visual_masks == 1).sum(), (visual_masks == 0).sum())

            mm_features = mm_features[visual_masks]
            # text_masks = torch.zeros_like(input_ids, dtype=torch.bool)
            # text_masks[~image_selected] = True
            text_masks = torch.logical_not(image_selected)

            try:
                text_masks[image_selected] = visual_masks
            except Exception as e:
                assert position_ids is not None, "Position ids must be provided when shapes mismatch"
                print(
                    f'warning: {e}, text_masks[image_selected].shape={text_masks[image_selected].shape},',
                    f'visual_masks.shape={visual_masks.shape}'
                )

                seq_end_indices = torch.nonzero(position_ids.view(B * N) == 0)[:, 0]
                seq_end_indices = seq_end_indices[seq_end_indices > 0]
                seq_end_indices = seq_end_indices.tolist() + [len(input_ids)]
                seq_start_indices = [0] + seq_end_indices[:-1]
                num_visual_tokens = [
                    input_ids[start:end].eq(self.config.image_token_index).sum()
                    for start, end in zip(seq_start_indices, seq_end_indices)
                ]

                for n, mask in zip(num_visual_tokens, visual_trunc_masks):
                    if len(mask) > 0:
                        mask[n:] = False
                visual_trunc_masks = torch.cat(visual_trunc_masks)

                text_masks[image_selected] = visual_masks[visual_trunc_masks]
                mm_features = mm_features[visual_trunc_masks[visual_masks]]

        else:
            text_masks = torch.ones_like(input_ids, dtype=torch.bool)

        input_ids = input_ids[text_masks]
        if attention_mask is not None:
            attention_mask = attention_mask.view(B * N)[text_masks].reshape(1, -1)
        if labels is not None:
            labels = labels.view(B * N)[text_masks].reshape(1, -1)
        if position_ids is not None:
            position_ids = position_ids.view(B * N)[text_masks]
            pos_start = [0] + torch.nonzero(position_ids == 0)[:, 0].tolist()
            pos_end = pos_start[1:] + [len(input_ids)]
            position_ids = torch.cat([torch.arange(end - start, device=input_ids.device) for start, end in zip(pos_start, pos_end)])
            position_ids = position_ids.reshape(1, -1)

        image_selected = (input_ids == self.config.image_token_index)
        audio_selected = (input_ids == MODAL_INDEX_MAP['<audio>'])
        input_ids[image_selected] = 0
        input_ids[audio_selected] = 0

        input_embeds = self.get_model().embed_tokens(input_ids).clone()

        input_embeds[image_selected] = input_embeds[image_selected] * 0.0 + mm_features.reshape(-1, C)

        # replace region token
        mask_selected = (input_ids == self.config.region_token_index)

        try:
            if mask_selected.sum() > 0:
                # try:
                # patches = np.ascontiguousarray(additional_images_list[0][0])
                # grid_h = additional_images_thw[0][0][0][1]
                # grid_w = additional_images_thw[0][0][0][2]
                # patches = patches.reshape(grid_h, grid_w, 3, 14, 14)
                # from matplotlib import pyplot as plt
                # plt.imshow(patches[:,:,:,0,0])
                # plt.savefig('7.png')
                # import pdb
                # pdb.set_trace()
                # patches = patches.transpose(2, 0, 3, 1, 4)
                # reconstructed_image = patches.reshape(3, grid_h*14, grid_w*14).transpose(1, 2, 0)
                # from matplotlib import pyplot as plt
                # plt.imshow(reconstructed_image)
                # plt.savefig('7.png')
                # import pdb
                # pdb.set_trace()
                additional_images_features = self.get_model().get_vision_encoder()(additional_images_list, grid_thws=additional_images_thw, strides=additional_images_strides)
                reshaped_features = reshape_images_to_raw_grid(additional_images_features, additional_images_thw)

                # mask_additional_image_features = []
                # for idx in mask_ids:
                #     mask_additional_image_features.append(reshaped_features[idx])

                mask_feats = self.model.region_encoder(reshaped_features, masks)

                input_embeds[mask_selected] = input_embeds[mask_selected]*0.0 + mask_feats
                # except: #FIXME
                #     print('additional_images_list is empty...')
        except Exception as exp:
            print('error: ', exp)
        new_input_embeds = input_embeds.reshape(1, -1, C)

        return None, attention_mask, past_key_values, new_input_embeds, labels, position_ids
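An illustrative, standalone sketch of the token-compression rule used in prepare_inputs_labels_for_multimodal_with_compression above: a patch token of frame t survives only if it differs from the same patch in frame t-1 by more than a threshold, and the first frame is always kept. The function name `frame_token_mask` and the toy shapes are assumptions for the example, not repository APIs.

import torch

def frame_token_mask(frames: torch.Tensor, threshold: float = 0.1, min_tokens: int = 1) -> torch.Tensor:
    # frames: [T, P, C] per-frame patch features in [0, 1] pixel space.
    diff = (frames[1:] - frames[:-1]).abs().mean(dim=-1) * 255                  # [T-1, P]
    diff = torch.cat([torch.full_like(diff[0:1], threshold + 1), diff], dim=0)  # always keep frame 0
    mask = diff > threshold                                                     # [T, P]
    # guarantee at least `min_tokens` surviving tokens per frame
    starving = (mask.sum(dim=1) < min_tokens).nonzero()[:, 0]
    mask[starving, :min_tokens] = True
    return mask.flatten()                                                       # [T * P]

frames = torch.rand(4, 16, 3)   # 4 frames, 16 patches, mean RGB per patch
frames[2] = frames[1]           # a static frame: most of its tokens get dropped
keep = frame_token_mask(frames)
print(keep.sum().item(), "of", keep.numel(), "tokens kept")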
videollama3/model/videollama3_qwen2.py
ADDED
@@ -0,0 +1,163 @@
# Adopted from: https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (AutoConfig, AutoModelForCausalLM, Qwen2Config,
                          Qwen2ForCausalLM, Qwen2Model)
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast

from .videollama3_arch import Videollama3MetaForCausalLM, Videollama3MetaModel


class Videollama3Qwen2Config(Qwen2Config):
    model_type = "videollama3_qwen2"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_type = "videollama3_qwen2"


class Videollama3Qwen2Model(Videollama3MetaModel, Qwen2Model):
    config_class = Videollama3Qwen2Config

    def __init__(self, config: Videollama3Qwen2Config):
        super(Videollama3Qwen2Model, self).__init__(config)


class Videollama3Qwen2ForCausalLM(Qwen2ForCausalLM, Videollama3MetaForCausalLM):
    config_class = Videollama3Qwen2Config

    def __init__(self, config, **kwargs):
        super(Qwen2ForCausalLM, self).__init__(config)
        self.model = Videollama3Qwen2Model(config)
        # self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[int] = None,
        masks: Optional[List[torch.LongTensor]] = None,
        additional_images=None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if inputs_embeds is None:
            (
                input_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels, position_ids
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                position_ids,
                masks,
                additional_images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        additional_images = kwargs.pop("additional_images", None)
        masks = kwargs.pop("masks", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                input_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                _,
                position_ids
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids=inputs,
                attention_mask=attention_mask,
                past_key_values=None,
                labels=None,
                images=images,
                position_ids=position_ids,
                additional_images=additional_images,
                masks=masks,
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs


AutoConfig.register("videollama3_qwen2", Videollama3Qwen2Config)
AutoModelForCausalLM.register(Videollama3Qwen2Config, Videollama3Qwen2ForCausalLM)
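A hedged usage sketch: once the two register() calls above have executed (i.e. this module has been imported), the Auto classes resolve the custom "videollama3_qwen2" model type. The checkpoint path below is a placeholder, not a real repository ID.

import torch
from transformers import AutoConfig, AutoModelForCausalLM

import videollama3.model.videollama3_qwen2  # noqa: F401  (runs the register() calls above)

config = AutoConfig.from_pretrained("path/to/videollama3_qwen2_checkpoint")
model = AutoModelForCausalLM.from_pretrained(
    "path/to/videollama3_qwen2_checkpoint",
    config=config,
    torch_dtype=torch.bfloat16,
)
print(type(model).__name__)  # Videollama3Qwen2ForCausalLM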
videollama3/train.py
ADDED
@@ -0,0 +1,798 @@
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import copy
import json
import os
import pathlib
import random
import re
import sys
import warnings
import traceback
from packaging import version
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence
import numpy as np

# torch-related packages
# NOTE: torch must be imported before transformers. Otherwise, `Segmentation fault (core dumped)` will occur.
import torch
import transformers
from packaging import version
from datasets import load_dataset, concatenate_datasets
from torch.utils.data import Dataset
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

sys.path.append('./')

from videollama3.constants import (IGNORE_INDEX, MODAL_INDEX_MAP,
                                   NUM_FRAMES, DEFAULT_IMAGE_TOKEN, STREAM_MAX_FRAMES,
                                   STREAM_DOWNSAMPLING, STREAM_FPS, STREAM_IMAGE_SIZE,
                                   STREAM_START_TOKEN, STREAM_END_TOKEN, REGION_TOKEN)
from videollama3.mm_utils import (load_images, load_video,
                                  tokenizer_multimodal_token, annToMask, resize_image_mask)
from videollama3.model import *
from videollama3.videollama3_trainer import (
    VideoLLaMA3Trainer, find_all_linear_names, get_peft_state_maybe_zero_3,
    get_peft_state_non_lora_maybe_zero_3, safe_save_model_for_hf_trainer)
from videollama3.model.processor import Videollama3Processor

# NOTE: fast tokenizer warning issue: https://github.com/huggingface/transformers/issues/5486
os.environ["TOKENIZERS_PARALLELISM"] = "true"

local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def set_seed(seed=42):
    """
    Set the random seed for reproducible results.

    :param seed: An integer value to be used as the random seed.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def int_with_none(value):
    if value == 'None':
        return None
    return int(value)


@dataclass
class ModelArguments:
    # LLM Arguments
    model_type: Optional[str] = field(default="videollama3", metadata={"help": "Model type selected in the list: " + ", ".join(VLLMs.keys())})
    model_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.5")
    version: Optional[str] = field(default="v1", metadata={"help": "Version of the conversation template."})
    freeze_backbone: bool = field(default=False, metadata={"help": "Whether to freeze the LLM backbone."})
    # Connector Arguments
    mm_projector_type: Optional[str] = field(default='linear')
    pretrain_mm_projector: Optional[str] = field(default=None)
    # Vision tower Arguments
    vision_encoder: Optional[str] = field(default=None)
    mm_vision_select_layer: Optional[int] = field(default=-1)
    mm_vision_select_feature: Optional[str] = field(default="patch")
    mm_attn_implementation: Optional[str] = field(default="flash_attention_2")
    # Token downsampling Arguments
    spatial_merge_size: Optional[int] = field(default=1)
    mm_max_length: Optional[int] = field(default=9477)
    use_token_compression: Optional[bool] = field(default=False)


@dataclass
class DataArguments:
    # Path Arguments
    data_path: List[str] = field(default=None, metadata={"help": "Path to the training data."})
    # image_folder: Optional[str] = field(default=None)
    # video_folder: Optional[str] = field(default=None)
    data_folder: Optional[str] = field(default=None)
    # Loading Arguments
    is_multimodal: bool = False
    fps: Optional[int] = field(default=None)
    max_frames: Optional[int_with_none] = field(default=None)
    # Preprocess Arguments
    image_aspect_ratio: str = 'square'
    use_batch_flattening: bool = field(default=True, metadata={"help": "Whether to flatten the in-batch sequences of variable lengths."})
    dataset_cache_dir: Optional[str] = field(default=None)


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # shut auto processing (_remove_unused_columns) of transformers Trainer
    remove_unused_columns: bool = field(default=False)

    optim: str = field(default="adamw_torch")
    # Training learning rate Arguments
    vision_encoder_lr: Optional[float] = None
    mm_projector_lr: Optional[float] = None
    llm_lr: Optional[float] = None
    region_encoder_lr: Optional[float] = None
    # Training Data Arguments
    group_by_modality_length: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Lora or Quant Arguments
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, vlprocessor, data_args: DataArguments):
        super(LazySupervisedDataset, self).__init__()
        data_objs = []
        # try:
        #     for data in data_path:
        #         # NOTE: load_dataset can process both json or jsonl files
        #         if data.endswith(".json") or data.endswith(".jsonl"):
        #             data_objs.append(load_dataset("json", data_files=data, cache_dir=data_args.dataset_cache_dir)["train"])
        #         else:
        #             raise Exception(f"Unsupported file format (<{data}>)!")
        #     list_data_dict = concatenate_datasets(data_objs)
        # except:
        traceback.print_exc()
        # NOTE: compatible with the old version
        list_data_dict = []
        for data in data_path:
            if data.endswith(".json"):
                data = json.load(open(data, "r"))
                for i in data:
                    i['id'] = len(list_data_dict)
                    list_data_dict.append(i)
            elif data.endswith(".jsonl"):
                with open(data, "r", encoding="utf-8") as fp:
                    for line in fp:
                        line = line.strip()
                        obj = json.loads(line)
                        obj["id"] = len(list_data_dict)
                        list_data_dict.append(obj)
            else:
                raise Exception(f"Unsupported file format (<{data}>)!!!")

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.vlprocessor = vlprocessor
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 576 if 'image' in sample else 0
            length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
            cur_len = cur_len if 'image' in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def _convert_normal(self, data_dict):
        data_folder = self.data_args.data_folder
        conversation = copy.deepcopy(data_dict["conversations"])

        # data sanity check and repair
        start_idx = 0
        for sentence in conversation:
            if sentence["from"] == "human" or sentence["from"] == "system":
                break
            start_idx += 1
        if start_idx > 0:
            warnings.warn(f"Find {start_idx} non-user sentences at the beginning of the conversation, remove them automatically!")
            conversation = conversation[start_idx:]
        assert len(conversation) > 1, f"Invalid conversation"

        additional_frames = []
        mask_ids = []
        if 'image' in data_dict and data_dict['image'] is not None:
            modal = 'image'
            if all(not "<image>" in sentence["value"] for sentence in conversation):
                warnings.warn(f"Image tag not found in the conversation, add it automatically at the beginning!")
                conversation[0]["value"] = "<image>" + conversation[0]["value"]
            image_file = data_dict['image']
            if isinstance(image_file, list):
                image_file = [os.path.join(data_folder, f) for f in image_file]
            else:
                image_file = os.path.join(data_folder, image_file)
            images = load_images(image_file)

            masks = []
            if 'masks' in data_dict and data_dict['masks'] is not None and len(data_dict['masks']) > 0:
                if 'height' in data_dict:
                    h = data_dict['height']
                    w = data_dict['width']
                else:
                    h = None
                    w = None
                for ann in data_dict['masks']:
                    mask = annToMask(ann, h, w)
                    masks.append(mask)
                    mask_ids.append(0)
                masks = np.stack(masks, axis=0)
                masks = torch.from_numpy(masks)

                additional_frames = images.copy()
            else:
                masks = None

        elif 'video' in data_dict and data_dict['video'] is not None:
            modal = 'video'
            if all(not "<video>" in sentence["value"] for sentence in conversation):
                warnings.warn(f"Video tag not found in the conversation, add it automatically at the beginning!")
                conversation[0]["value"] = "<video>" + conversation[0]["value"]
            video_file = data_dict['video']

            masks = []
            frame_ids = []
            if 'masks' in data_dict and data_dict['masks'] is not None:
                if 'height' in data_dict:
                    h = data_dict['height']
                    w = data_dict['width']
                else:
                    h = None
                    w = None
                for ann in data_dict['masks']:
                    for k in ann.keys():
                        if int(k) not in frame_ids:
                            frame_ids.append(int(k))
                        mask_ids.append(frame_ids.index(int(k)))
                        mask = annToMask(ann[k], h, w)
                        masks.append(mask)
                masks = np.stack(masks, axis=0)
                masks = torch.from_numpy(masks)
            else:
                masks = None

            if isinstance(video_file, list) and len(video_file) == 1:
                video_file = os.path.join(data_folder, video_file[0])
                images, timestamps, additional_frames = load_video(video_file, fps=self.data_args.fps, max_frames=self.data_args.max_frames, frame_ids=frame_ids)
            elif isinstance(video_file, list) and len(video_file) > 1:  # images
                images = []
                for vf in video_file:
                    images += load_images(os.path.join(data_folder, vf))
                timestamps = data_dict['timestamps']
                additional_frames = []
                for mv in data_dict['masked_video']:
                    additional_frames += load_images(os.path.join(data_folder, mv))
            else:
                raise ValueError(f"Unsupported video format: {video_file}")
        else:
            modal = 'text'
            images = []
            masks = None

        if masks is not None and len(masks) > 0:
            additional_frames, masks, mask_nums = resize_image_mask(additional_frames, masks, mask_ids)
            conv_i = 0
            for idx in range(len(mask_nums)):
                while '<region>' not in conversation[conv_i]['value']:
                    conv_i += 1
                conversation[conv_i]['value'] = conversation[conv_i]['value'].replace('<region>', "[" + REGION_TOKEN * mask_nums[idx] + "]", 1)

        messages = []
        for conv in conversation:
            if conv["from"] == "human":
                # replace video tag to image tag for unified processing
                # conv["value"] = conv["value"].replace("<video>", "<image>" * len(images))
                chunks = conv["value"].split("<image>" if modal == 'image' else "<video>")
                messages.append({
                    "role": "user",
                    "content": []
                })

                for chunk_idx in range(1, 2 * len(chunks)):
                    if chunk_idx % 2 == 1:
                        chunk = chunks[chunk_idx // 2].strip()
                        messages[-1]["content"].append({"type": "text", "text": chunk}) if chunk else None
                    else:
                        if modal == 'image':
                            messages[-1]["content"].append({"type": "image"})
                        elif modal == 'video':
                            messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})
            else:
                messages.append({
                    "role": "assistant",
                    "content": conv['value']
                })

        # TODO: dynamic downsampling
        # image_downsampling = self.data_args.spatial_merge_size
        image_downsampling = self.data_args.spatial_merge_size if modal == "video" else 1
        # if modal == 'video':
        #     image_downsampling = 2
        # else:
        #     # image/text
        #     image_downsampling = 1

        return modal, images, messages, image_downsampling, masks, additional_frames
|
363 |
+
|
364 |
+
def _convert_stream(self, data_dict):
|
365 |
+
video_path = os.path.join(self.data_args.data_folder, data_dict['video'][0])
|
366 |
+
frames, timestamps = load_video(
|
367 |
+
video_path=video_path,
|
368 |
+
start_time=data_dict["start_time"],
|
369 |
+
end_time=data_dict["end_time"],
|
370 |
+
fps=self.data_args.fps,
|
371 |
+
max_frames=self.data_args.max_frames,
|
372 |
+
size=STREAM_IMAGE_SIZE,
|
373 |
+
# size_divisible=14 * STREAM_DOWNSAMPLING,
|
374 |
+
)
|
375 |
+
|
376 |
+
if len(frames) > STREAM_MAX_FRAMES:
|
377 |
+
max_time = timestamps[STREAM_MAX_FRAMES]
|
378 |
+
frames = frames[:STREAM_MAX_FRAMES]
|
379 |
+
timestamps = timestamps[:STREAM_MAX_FRAMES]
|
380 |
+
else:
|
381 |
+
max_time = float("inf")
|
382 |
+
|
383 |
+
messages = []
|
384 |
+
frame_idx = 0
|
385 |
+
|
386 |
+
conversation = copy.deepcopy(data_dict["conversation"])
|
387 |
+
for message in conversation:
|
388 |
+
if message["time"] >= max_time:
|
389 |
+
break
|
390 |
+
|
391 |
+
while frame_idx < len(timestamps) and timestamps[frame_idx] <= message["time"]:
|
392 |
+
messages.append({
|
393 |
+
"role": "stream",
|
394 |
+
"content": [{"type": "image", "time": timestamps[frame_idx] - data_dict["start_time"]}],
|
395 |
+
})
|
396 |
+
frame_idx += 1
|
397 |
+
|
398 |
+
messages.append(message)
|
399 |
+
|
400 |
+
frames = frames[:frame_idx]
|
401 |
+
|
402 |
+
# return "video", frames, messages, STREAM_DOWNSAMPLING
|
403 |
+
return "video", frames, messages, self.data_args.spatial_merge_size
|
404 |
+
|
405 |
+
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
406 |
+
data_dict = self.list_data_dict[i]
|
407 |
+
|
408 |
+
try:
|
409 |
+
if "stream" in data_dict and data_dict["stream"]:
|
410 |
+
modal, images, messages, image_downsampling = self._convert_stream(data_dict)
|
411 |
+
else:
|
412 |
+
modal, images, messages, image_downsampling, masks, additional_frames = self._convert_normal(data_dict)
|
413 |
+
|
414 |
+
data_dict = self.vlprocessor(
|
415 |
+
images=images,
|
416 |
+
text=messages,
|
417 |
+
image_downsampling=image_downsampling,
|
418 |
+
return_labels=True,
|
419 |
+
return_tensors="pt",
|
420 |
+
)
|
421 |
+
if len(additional_frames)>0:
|
422 |
+
additional_images_dict = self.vlprocessor._process_image(additional_frames, num_images=1, image_downsampling=1)
|
423 |
+
additional_images = additional_images_dict['images']
|
424 |
+
additional_images_thws = additional_images_dict['grid_thws']
|
425 |
+
else:
|
426 |
+
additional_images = []
|
427 |
+
additional_images_thws = []
|
428 |
+
|
429 |
+
if modal == 'text':
|
430 |
+
unit_size = self.vlprocessor.image_processor.patch_size**2 * 3 * self.vlprocessor.image_processor.temporal_patch_size
|
431 |
+
data_dict['images'] = [torch.zeros(self.data_args.spatial_merge_size**2, unit_size)]
|
432 |
+
data_dict['grid_thws'] = [torch.tensor([[1, self.data_args.spatial_merge_size, self.data_args.spatial_merge_size]])]
|
433 |
+
elif modal == 'image' or modal == 'video':
|
434 |
+
assert len(data_dict['images']) > 0 and len(data_dict['grid_thws']) > 0, f"Invalid image data: {data_dict['images']}, {data_dict['grid_thws']}"
|
435 |
+
|
436 |
+
data_dict['modal'] = modal
|
437 |
+
data_dict['masks'] = masks
|
438 |
+
data_dict['additional_images'] = additional_images
|
439 |
+
data_dict['additional_images_thws'] = additional_images_thws
|
440 |
+
|
441 |
+
except Exception as e:
|
442 |
+
traceback.print_exc()
|
443 |
+
backup_idx = random.randint(0, len(self.list_data_dict) - 1)
|
444 |
+
print(f"Encounted error when process {i}-th example: {data_dict}, use {backup_idx}-th example instead!!!")
|
445 |
+
return self.__getitem__(backup_idx)
|
446 |
+
|
447 |
+
return data_dict
|

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    vlprocessor: transformers.ProcessorMixin

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.vlprocessor.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.vlprocessor.tokenizer.model_max_length]
        labels = labels[:, :self.vlprocessor.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.vlprocessor.tokenizer.pad_token_id),
        )

        # work for 'images' argument in `prepare_inputs_labels_for_multimodal`
        batch['images'] = []
        batch['additional_images'] = []
        batch["masks"] = []
        mask_idx_start = 0
        for instance in instances:
            # for modal_token in MODAL_INDEX_MAP.keys():
            #     modal_token = modal_token.lower()
            #     # MODAL_TOKEN shape like: <image>, <video>, ...
            #     modal_name = re.findall(f'[<](.*)[>]', modal_token)
            #     assert len(modal_name) == 1
            #     modal_name = modal_name[0]
            batch['images'].append((instance['modal'], instance['images'], instance['grid_thws']))
            if len(instance['additional_images']) > 0:
                batch['additional_images'].append((instance['additional_images'], instance['additional_images_thws']))
            if instance["masks"] is not None:
                batch["masks"].append(instance["masks"])
            mask_idx_start += len(instance['additional_images'])
        return batch
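For reference, a minimal sketch (not part of the committed file) of the padding behaviour the collator above relies on: torch.nn.utils.rnn.pad_sequence right-pads the variable-length sequences, and the attention mask is recovered by comparing against the pad id. The tensor values are invented for illustration.

import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0  # hypothetical pad id
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Right-pad to the longest sequence in the batch, as the collator does.
input_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_token_id)
attention_mask = input_ids.ne(pad_token_id)

print(input_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
print(attention_mask)  # tensor([[True, True, True], [True, True, False]])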

def make_supervised_data_module(vlprocessor, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(
        vlprocessor=vlprocessor,
        data_path=data_args.data_path,
        data_args=data_args
    )
    data_collator = DataCollatorForSupervisedDataset(vlprocessor=vlprocessor)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
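As a rough usage sketch (assuming `vlprocessor` and `data_args` have already been built as in train() below; this is not part of the committed file), the returned dictionary can be smoke-tested with a plain DataLoader before handing it to the trainer:

from torch.utils.data import DataLoader

module = make_supervised_data_module(vlprocessor=vlprocessor, data_args=data_args)
loader = DataLoader(
    module["train_dataset"],
    batch_size=2,
    collate_fn=module["data_collator"],
)
batch = next(iter(loader))
# The batch holds padded input_ids / labels / attention_mask plus the per-sample
# (modal, images, grid_thws) tuples collected in batch["images"].
print(batch["input_ids"].shape)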

@dataclass
class DataCollatorWithFlatteningForSupervisedDataset(object):
    """Collate examples for batch flattened supervised fine-tuning."""

    vlprocessor: transformers.ProcessorMixin

    def __call__(self, instances: Sequence[Dict], separator_id=-100) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))

        new_input_ids = []
        new_labels = []
        position_ids = []
        for idx in range(0, len(input_ids)):
            new_input_ids.append(input_ids[idx][:self.vlprocessor.tokenizer.model_max_length])
            temp_label = labels[idx][:self.vlprocessor.tokenizer.model_max_length]
            temp_label[0] = separator_id
            new_labels.append(temp_label)
            position_ids.append(torch.tensor(list(range(len(input_ids[idx][:self.vlprocessor.tokenizer.model_max_length])))))

        new_input_ids = torch.cat(new_input_ids)
        new_labels = torch.cat(new_labels)
        position_ids = torch.cat(position_ids)

        batch = dict(
            input_ids=new_input_ids.unsqueeze(0),
            labels=new_labels.unsqueeze(0),
            position_ids=position_ids.unsqueeze(0),
        )

        # work for 'images' argument in `prepare_inputs_labels_for_multimodal`
        batch['images'] = []
        batch['additional_images'] = []
        # mask_idx_start = 0
        for instance in instances:
            batch['images'].append((instance['modal'], instance['images'], instance['grid_thws']))
            if len(instance['additional_images']) > 0:
                batch['additional_images'].append((instance['additional_images'], instance['additional_images_thws']))
                # mask_idx_start += len(instance['additional_images'])
        batch["masks"] = [x["masks"] for x in instances]
        return batch
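The flattening collator packs every sample of the mini batch into a single row and carries per-sample position_ids instead of an attention mask, which is why train() below requires flash_attention_2. A tiny illustration of the resulting layout (token values invented):

import torch

# Two samples of lengths 3 and 2 are packed into one row of length 5.
input_ids = torch.tensor([[5, 6, 7, 8, 9]])
# position_ids restart at 0 for each packed sample, letting the attention kernel
# treat them as separate sequences without any padding.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
# The first label of each packed sample is set to the separator id (-100),
# so no loss is computed across the boundary between samples.
labels = torch.tensor([[-100, 6, 7, -100, 9]])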

def make_flattening_supervised_data_module(vlprocessor: transformers.ProcessorMixin, data_args) -> Dict:
    """Make batch flattened dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(
        vlprocessor=vlprocessor,
        data_path=data_args.data_path,
        data_args=data_args
    )
    data_collator = DataCollatorWithFlatteningForSupervisedDataset(vlprocessor=vlprocessor)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)


def train(attn_implementation=None):
    global local_rank
    set_seed(42)

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    local_rank = training_args.local_rank

    if local_rank == 0:
        print('------model args------')
        print(model_args)
        print('------data args------')
        print(data_args)
        print('------training args------')
        print(training_args)

    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))

    bnb_model_from_pretrained_args = {}
    if training_args.bits in [4, 8]:
        from transformers import BitsAndBytesConfig
        bnb_model_from_pretrained_args.update(dict(
            # device_map={"": training_args.device},
            # BUG: newer transformers versions report an error:
            # ValueError: You can't pass `load_in_4bit` or `load_in_8bit` as a kwarg when passing `quantization_config` argument at the same time
            # load_in_4bit=training_args.bits == 4,
            # load_in_8bit=training_args.bits == 8,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=training_args.bits == 4,
                load_in_8bit=training_args.bits == 8,
                llm_int8_skip_modules=["mm_projector"],
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=training_args.double_quant,
                bnb_4bit_quant_type=training_args.quant_type,  # {'fp4', 'nf4'}
                bnb_4bit_quant_storage=compute_dtype,
            )
        ))

    config = VLLMConfigs[model_args.model_type].from_pretrained(model_args.model_path)

    config._attn_implementation = attn_implementation
    # NOTE: activate spatial_merge_size arguments
    config.spatial_merge_size = model_args.spatial_merge_size
    config.mm_max_length = model_args.mm_max_length
    config.use_token_compression = model_args.use_token_compression

    if model_args.vision_encoder is not None:
        model = VLLMs[model_args.model_type].from_pretrained(
            model_args.model_path,
            config=config,
            torch_dtype=compute_dtype,
            do_sample=True,
            **bnb_model_from_pretrained_args
        )
        if 'mixtral' in model_args.model_type:
            import deepspeed
            deepspeed.utils.set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
    else:
        model = transformers.LlamaForCausalLM.from_pretrained(
            model_args.model_path,
            config=config,
            torch_dtype=compute_dtype,
            do_sample=True,
            **bnb_model_from_pretrained_args
        )
    model.config.use_cache = False
    if model_args.freeze_backbone:
        model.model.requires_grad_(False)

    if training_args.bits in [4, 8]:
        from peft import prepare_model_for_kbit_training
        model.config.torch_dtype = (torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

    if training_args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    if training_args.lora_enable:
        from peft import LoraConfig, get_peft_model
        lora_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=find_all_linear_names(model),
            lora_dropout=training_args.lora_dropout,
            bias=training_args.lora_bias,
            task_type="CAUSAL_LM",
        )
        if training_args.bits == 16:
            if training_args.bf16:
                model.to(torch.bfloat16)
            if training_args.fp16:
                model.to(torch.float16)
        rank0_print("Adding LoRA adapters...")
        model = get_peft_model(model, lora_config)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.unk_token

    if model_args.vision_encoder is not None:
        # initialize vision encoder + multi-modal projector
        model.get_model().initialize_vision_modules(model_args=model_args, fsdp=training_args.fsdp)

        vision_encoder = model.get_vision_encoder()
        vision_encoder.to(dtype=compute_dtype, device=training_args.device)

        mm_projector = model.get_mm_projector()
        mm_projector.to(dtype=compute_dtype if training_args.bf16 else torch.float16, device=training_args.device)

        data_args.is_multimodal = True

        model.config.tokenizer_padding_side = tokenizer.padding_side
        model.config.tokenizer_model_max_length = tokenizer.model_max_length

        if training_args.bits in [4, 8]:
            model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)

        # decoupled learning rate
        model.config.llm_lr = training_args.llm_lr
        model.config.vision_encoder_lr = training_args.vision_encoder_lr
        model.config.mm_projector_lr = training_args.mm_projector_lr
        model.config.region_encoder_lr = training_args.region_encoder_lr

        if model.config.llm_lr is None:
            for p in model.get_model().parameters():
                p.requires_grad = False
            for p in model.get_model().vision_encoder.parameters():
                p.requires_grad = True
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = True
            for p in model.get_model().region_encoder.parameters():
                p.requires_grad = True

        if model.config.vision_encoder_lr is None:
            for p in model.get_model().vision_encoder.parameters():
                p.requires_grad = False

        if model.config.mm_projector_lr is None:
            for p in model.get_model().mm_projector.parameters():
                p.requires_grad = False

        if model.config.region_encoder_lr is None:
            for p in model.get_model().region_encoder.parameters():
                p.requires_grad = False

        model.config.max_frames = getattr(data_args, 'max_frames', NUM_FRAMES)
        model.config.image_aspect_ratio = data_args.image_aspect_ratio if 'qwen2vl' not in model_args.vision_encoder else 'qwen2vl'

        # NOTE: complement data_args via model hyperparameters
        # 1. acquire image size
        model.config.image_size = data_args.image_size = vision_encoder.image_size
        # 2. calculate the number of tokens in the image
        model.config.image_token_length = data_args.image_token_length = mm_projector.cal_proj_size(vision_encoder.num_patches_per_side)
        # 3. check if alignment
        model.config.is_alignment = training_args.is_alignment = data_args.is_alignment = (
            model.config.mm_projector_lr is not None and
            model.config.llm_lr is None and
            model.config.vision_encoder_lr is None
        )
        # 4. set spatial merge size as default
        model.config.spatial_merge_size = data_args.spatial_merge_size = model_args.spatial_merge_size
        tokenizer.add_tokens([DEFAULT_IMAGE_TOKEN, STREAM_START_TOKEN, STREAM_END_TOKEN], special_tokens=True)
        tokenizer.add_tokens([REGION_TOKEN], special_tokens=True)
        model.resize_token_embeddings(len(tokenizer))

        model.config.image_token_index = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
        model.config.region_token_index = tokenizer.convert_tokens_to_ids(REGION_TOKEN)

    vlprocessor = Videollama3Processor(vision_encoder.image_processor, tokenizer)

    if training_args.bits in [4, 8]:
        from peft.tuners.lora import LoraLayer
        for name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                if training_args.bf16:
                    module = module.to(torch.bfloat16)
            if 'norm' in name:
                module = module.to(torch.float32)
            if 'lm_head' in name or 'embed_tokens' in name:
                if hasattr(module, 'weight'):
                    if training_args.bf16 and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)

    if local_rank == 0:
        print("Current model:", model)
        print("Model config:", model.config)

    if data_args.use_batch_flattening:
        rank0_print('You are using the flattening operation to flatten the entire mini batch into a single sequence')
        assert model.config._attn_implementation == 'flash_attention_2'
        assert version.parse(transformers.__version__) >= version.parse("4.44.0")
        data_module = make_flattening_supervised_data_module(vlprocessor=vlprocessor, data_args=data_args)
    else:
        data_module = make_supervised_data_module(vlprocessor=vlprocessor, data_args=data_args)

    # select a Trainer
    trainer = VideoLLaMA3Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    if training_args.lora_enable:
        state_dict = get_peft_state_maybe_zero_3(model.named_parameters(), training_args.lora_bias)
        non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
        if training_args.local_rank == 0 or training_args.local_rank == -1:
            model.config.save_pretrained(training_args.output_dir)
            model.save_pretrained(training_args.output_dir, state_dict=state_dict)
            torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
    else:
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train(attn_implementation="flash_attention_2")
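A hypothetical invocation sketch (not the project's documented CLI): HfArgumentParser turns the dataclass fields into --flags, so the script can also be driven programmatically. The flag names below are inferred from the attributes used in train() above, and the checkpoint and annotation paths are placeholders.

import transformers

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_path", "path/to/base-model",         # placeholder checkpoint
    "--data_path", "path/to/annotations.jsonl",   # placeholder annotation file
    "--output_dir", "./work_dirs/debug",
    "--bf16", "True",
])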
videollama3/videollama3_trainer.py
ADDED
@@ -0,0 +1,398 @@
# Adapted from: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py
import os
import logging
from typing import List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Sampler

from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
    TRAINER_STATE_NAME,
)


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
    return to_return


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias:
            if bias_name in lora_bias_names:
                to_return[bias_name] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_encoder', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


def safe_save_model_for_hf_trainer(trainer: Trainer,
                                   output_dir: str):
    """Collects the state dict and dumps it to disk."""

    if getattr(trainer.args, "is_alignment", False):
        # Only save Adapter
        keys_to_match = ['mm_projector']

        weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
        trainer.model.config.save_pretrained(output_dir)

        current_folder = output_dir.split('/')[-1]
        parent_folder = os.path.dirname(output_dir)
        # if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
        if torch.distributed.get_rank() == 0:
            if current_folder.startswith('checkpoint-'):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
        return

    if trainer.deepspeed:
        torch.cuda.synchronize()
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {
            key: value.cpu()
            for key, value in state_dict.items()
        }
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def split_to_even_chunks(indices, lengths, num_chunks):
    """
    Split a list of indices into `num_chunks` chunks of roughly equal total length.
    """

    if len(indices) % num_chunks != 0:
        return [indices[i::num_chunks] for i in range(num_chunks)]

    num_indices_per_chunk = len(indices) // num_chunks

    chunks = [[] for _ in range(num_chunks)]
    chunks_lengths = [0 for _ in range(num_chunks)]
    for index in indices:
        shortest_chunk = chunks_lengths.index(min(chunks_lengths))
        chunks[shortest_chunk].append(index)
        chunks_lengths[shortest_chunk] += lengths[index]
        if len(chunks[shortest_chunk]) == num_indices_per_chunk:
            chunks_lengths[shortest_chunk] = float("inf")

    return chunks
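A tiny worked example (not part of the original file) of the greedy balancing above: each index is assigned to the currently shortest chunk, so the summed lengths stay roughly even.

indices = [0, 1, 2, 3, 4, 5]
lengths = [9, 1, 8, 2, 7, 3]
print(split_to_even_chunks(indices, lengths, num_chunks=2))
# -> [[0, 3, 5], [1, 2, 4]]  (total lengths 14 and 16)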

def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    assert all(l != 0 for l in lengths), "Should not have zero length."
    if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
        # all samples are in the same modality
        return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
    mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
    lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])

    mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
    lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
    megabatch_size = world_size * batch_size
    mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
    lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]

    last_mm = mm_megabatches[-1]
    last_lang = lang_megabatches[-1]
    additional_batch = last_mm + last_lang
    megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
    megabatch_indices = torch.randperm(len(megabatches), generator=generator)
    megabatches = [megabatches[i] for i in megabatch_indices]

    if len(additional_batch) > 0:
        megabatches.append(sorted(additional_batch))

    return [i for megabatch in megabatches for i in megabatch]
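The sign convention consumed above mirrors the dataset's modality_lengths property: multimodal samples report positive lengths and text-only samples negative ones, so the two groups are length-grouped separately before the full megabatches are shuffled together (leftovers from both groups end up merged in a final batch). A small sketch with invented lengths:

lengths = [120, -35, 98, -40, 110, -25]   # 3 multimodal, 3 text-only samples
indices = get_modality_length_grouped_indices(lengths, batch_size=1, world_size=2)
print(indices)  # a permutation of range(6)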

def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
    indices = torch.randperm(len(lengths), generator=generator)
    megabatch_size = world_size * batch_size
    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
    megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]

    return [i for megabatch in megabatches for batch in megabatch for i in batch]


class LengthGroupedSampler(Sampler):
    r"""
    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
    keeping a bit of randomness.
    """

    def __init__(
        self,
        batch_size: int,
        world_size: int,
        lengths: Optional[List[int]] = None,
        generator=None,
        group_by_modality: bool = False,
    ):
        if lengths is None:
            raise ValueError("Lengths must be provided.")

        self.batch_size = batch_size
        self.world_size = world_size
        self.lengths = lengths
        self.generator = generator
        self.group_by_modality = group_by_modality

    def __len__(self):
        return len(self.lengths)

    def __iter__(self):
        if self.group_by_modality:
            indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        else:
            indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
        return iter(indices)


class VideoLLaMA3Trainer(Trainer):

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None

        if self.args.group_by_modality_length:
            lengths = self.train_dataset.modality_lengths
            return LengthGroupedSampler(
                self.args.train_batch_size,
                world_size=self.args.world_size * self.args.gradient_accumulation_steps,
                lengths=lengths,
                group_by_modality=True,
            )
        else:
            return super()._get_train_sampler()

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        if is_sagemaker_mp_enabled():
            return super().create_optimizer()

        opt_model = self.model

        if self.optimizer is None:
            optimized_parameters = [(n, p) for n, p in opt_model.named_parameters() if p.requires_grad]
            optimizer_grouped_parameters = []

            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]

            if self.args.llm_lr is not None:
                lm_parameters = [
                    name for name, _ in optimized_parameters if "vision_encoder" not in name and "mm_projector" not in name and "region_encoder" not in name
                ]
                decay_lm_parameters = [name for name in lm_parameters if name in decay_parameters]
                nodecay_lm_parameters = [name for name in lm_parameters if name not in decay_parameters]
                optimizer_grouped_parameters.extend([
                    {
                        "params": [p for n, p in optimized_parameters if n in decay_lm_parameters],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.llm_lr,
                    },
                    {
                        "params": [p for n, p in optimized_parameters if n in nodecay_lm_parameters],
                        "weight_decay": 0.0,
                        "lr": self.args.llm_lr,
                    }
                ])

            if self.args.mm_projector_lr is not None:
                projector_parameters = [name for name, _ in optimized_parameters if "mm_projector" in name]
                decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
                nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
                optimizer_grouped_parameters.extend([
                    {
                        "params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.mm_projector_lr,
                    },
                    {
                        "params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
                        "weight_decay": 0.0,
                        "lr": self.args.mm_projector_lr,
                    }
                ])

            if self.args.vision_encoder_lr is not None:
                vision_encoder_parameters = [name for name, _ in optimized_parameters if "vision_encoder" in name]
                decay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name in decay_parameters]
                nodecay_vision_encoder_parameters = [name for name in vision_encoder_parameters if name not in decay_parameters]
                optimizer_grouped_parameters.extend([
                    {
                        "params": [p for n, p in optimized_parameters if n in decay_vision_encoder_parameters],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.vision_encoder_lr,
                    },
                    {
                        "params": [p for n, p in optimized_parameters if n in nodecay_vision_encoder_parameters],
                        "weight_decay": 0.0,
                        "lr": self.args.vision_encoder_lr,
                    }
                ])

            if self.args.region_encoder_lr is not None:
                projector_parameters = [name for name, _ in optimized_parameters if "region_encoder" in name]
                decay_projector_parameters = [name for name in projector_parameters if name in decay_parameters]
                nodecay_projector_parameters = [name for name in projector_parameters if name not in decay_parameters]
                optimizer_grouped_parameters.extend([
                    {
                        "params": [p for n, p in optimized_parameters if n in decay_projector_parameters],
                        "weight_decay": self.args.weight_decay,
                        "lr": self.args.region_encoder_lr,
                    },
                    {
                        "params": [p for n, p in optimized_parameters if n in nodecay_projector_parameters],
                        "weight_decay": 0.0,
                        "lr": self.args.region_encoder_lr,
                    }
                ])

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        return self.optimizer
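Conceptually, the decoupled learning rates above reduce to ordinary torch param groups: one group per component and decay setting, each with its own lr. A minimal, hypothetical sketch:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
optimizer = torch.optim.AdamW([
    {"params": [model[0].weight], "lr": 1e-5, "weight_decay": 0.01},           # e.g. llm_lr, with decay
    {"params": [model[0].bias], "lr": 1e-5, "weight_decay": 0.0},              # biases: no decay
    {"params": list(model[1].parameters()), "lr": 1e-4, "weight_decay": 0.0},  # e.g. mm_projector_lr
])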

    def _save_checkpoint(self, model, trial, metrics=None):
        if getattr(self.args, 'is_alignment', False):
            from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self._get_output_dir(trial=trial)
            output_dir = os.path.join(run_dir, checkpoint_folder)

            # Only save Adapter
            keys_to_match = ['mm_projector', 'vision_resampler']

            weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)

            if self.args.local_rank == 0 or self.args.local_rank == -1:
                self.model.config.save_pretrained(output_dir)
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
            # Save optimizer and scheduler
            self._save_optimizer_and_scheduler(output_dir)
            # Save RNG state
            self._save_rng_state(output_dir)
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
            self.args.distributed_state.wait_for_everyone()
        else:
            # NOTE: Support saving the complete LoRA checkpoint during training.
            if self.args.lora_enable:
                from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
                checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

                run_dir = self._get_output_dir(trial=trial)
                output_dir = os.path.join(run_dir, checkpoint_folder)

                state_dict = get_peft_state_maybe_zero_3(self.model.named_parameters(), self.args.lora_bias)
                non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(self.model.named_parameters())
                if self.args.local_rank == 0 or self.args.local_rank == -1:
                    # save for acquiring `config.json`
                    self.model.config.save_pretrained(output_dir)
                    # save for acquiring `adapter_config.json`, `adapter_model.bin`
                    # self.model.save_pretrained(output_dir, state_dict=state_dict)
                    torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))

                # save LoRA adapter parameters & trainer states: `adapter_config.json`, `adapter_model.safetensors`
                super(VideoLLaMA3Trainer, self)._save_checkpoint(model, trial, metrics)
            else:
                super(VideoLLaMA3Trainer, self)._save_checkpoint(model, trial, metrics)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        if getattr(self.args, 'is_alignment', False):
            pass
        else:
            super(VideoLLaMA3Trainer, self)._save(output_dir, state_dict)